{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": false
   },
   "source": [
    "### 数据处理"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "# Make libraries installed into a custom folder importable.\n",
    "# One-time setup (uncomment to run):\n",
    "# !mkdir /home/aistudio/external-libraries\n",
    "# !pip install imgaug -t /home/aistudio/external-libraries\n",
    "import sys\n",
    "\n",
    "EXTERNAL_LIB_DIR = '/home/aistudio/external-libraries'\n",
    "if EXTERNAL_LIB_DIR not in sys.path: # avoid duplicate entries when the cell is re-run\n",
    "    sys.path.append(EXTERNAL_LIB_DIR)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2021-01-16 09:53:03,355-INFO: font search path ['/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/matplotlib/mpl-data/fonts/ttf', '/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/matplotlib/mpl-data/fonts/afm', '/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/matplotlib/mpl-data/fonts/pdfcorefonts']\n",
      "2021-01-16 09:53:03,696-INFO: generated new fontManager\n",
      "Cache file /home/aistudio/.cache/paddle/dataset/cifar/cifar-100-python.tar.gz not found, downloading https://dataset.bj.bcebos.com/cifar/cifar-100-python.tar.gz \n",
      "Begin to download\n",
      "\n",
      "Download finished\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "image shape: (32, 32, 3)\n",
      "label value: cattle\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "<Figure size 300x300 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "import os\n",
    "\n",
    "import paddle\n",
    "import numpy as np\n",
    "from PIL import Image\n",
    "import matplotlib.pyplot as plt\n",
    "import imgaug as ia\n",
    "import imgaug.augmenters as iaa\n",
    "\n",
    "# Read one batch of the CIFAR-100 training set\n",
    "reader = paddle.batch(\n",
    "    paddle.dataset.cifar.train100(),\n",
    "    batch_size=8) # dataset reader\n",
    "data = next(reader()) # first batch: list of (image, label) samples\n",
    "index = 0 # position of the sample to show within the batch\n",
    "\n",
    "# Decode the image\n",
    "images = np.array([x[0] for x in data]).astype(np.float32) # pixel data as float32 in [0,1]\n",
    "images = np.clip(np.rint(images * 255), 0, 255) # [0,1] -> [0,255]; round so float error cannot truncate k to k-1\n",
    "image = images[index].reshape((3, 32, 32)).transpose((1, 2, 0)).astype(np.uint8) # CHW -> HWC, then uint8\n",
    "print('image shape:', image.shape)\n",
    "\n",
    "# Optional augmentation preview (uncomment to try)\n",
    "# sometimes = lambda aug: iaa.Sometimes(0.5, aug) # apply the wrapped augmenter with probability 0.5\n",
    "# seq = iaa.Sequential([\n",
    "#     sometimes(iaa.CropAndPad(px=(-4, 4))),      # random crop/pad by up to 4 pixels\n",
    "#     iaa.Fliplr(0.5)])                           # random horizontal flip\n",
    "# image = seq(image=image)\n",
    "\n",
    "# Decode the label\n",
    "label = np.array([x[1] for x in data]).astype(np.int64) # labels as int64\n",
    "vlist = ['beaver', 'dolphin', 'otter', 'seal', 'whale',\n",
    "         'aquarium fish', 'flatfish', 'ray', 'shark', 'trout',\n",
    "         'orchids', 'poppies', 'roses', 'sunflowers', 'tulips',\n",
    "         'bottles', 'bowls', 'cans', 'cups', 'plates',\n",
    "         'apples', 'mushrooms', 'oranges', 'pears', 'sweet peppers',\n",
    "         'clock', 'keyboard', 'lamp', 'telephone', 'television',\n",
    "         'bed', 'chair', 'couch', 'table', 'wardrobe',\n",
    "         'bee', 'beetle', 'butterfly', 'caterpillar', 'cockroach',\n",
    "         'bear', 'leopard', 'lion', 'tiger', 'wolf',\n",
    "         'bridge', 'castle', 'house', 'road', 'skyscraper',\n",
    "         'cloud', 'forest', 'mountain', 'plain', 'sea',\n",
    "         'camel', 'cattle', 'chimpanzee', 'elephant', 'kangaroo',\n",
    "         'fox', 'porcupine', 'possum', 'raccoon', 'skunk',\n",
    "         'crab', 'lobster', 'snail', 'spider', 'worm',\n",
    "         'baby', 'boy', 'girl', 'man', 'woman',\n",
    "         'crocodile', 'dinosaur', 'lizard', 'snake', 'turtle',\n",
    "         'hamster', 'mouse', 'rabbit', 'shrew', 'squirrel',\n",
    "         'maple', 'oak', 'palm', 'pine', 'willow',\n",
    "         'bicycle', 'bus', 'motorcycle', 'pickup truck', 'train',\n",
    "         'lawn-mower', 'rocket', 'streetcar', 'tank', 'tractor'] # class names\n",
    "vlist.sort() # alphabetical order: label ids index the sorted class names\n",
    "print('label value:', vlist[label[index]])\n",
    "\n",
    "# Show and save the image\n",
    "out_dir = './work/out'\n",
    "os.makedirs(out_dir, exist_ok=True) # image.save fails if the folder is missing\n",
    "image = Image.fromarray(image)                # numpy array -> PIL image\n",
    "image.save(os.path.join(out_dir, 'img.png'))  # save the image\n",
    "plt.figure(figsize=(3, 3))                    # display size\n",
    "plt.imshow(image)                             # draw the image\n",
    "plt.show()                                    # render the figure"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "train_data: image shape (128, 3, 32, 32), label shape:(128, 1)\n",
      "valid_data: image shape (128, 3, 32, 32), label shape:(128, 1)\n"
     ]
    }
   ],
   "source": [
    "import paddle\n",
    "import numpy as np\n",
    "import imgaug as ia\n",
    "import imgaug.augmenters as iaa\n",
    "\n",
    "# Per-channel statistics on the [0,1] scale, shaped to broadcast over an NHWC batch\n",
    "CIFAR_MEAN = np.array([0.4914, 0.4822, 0.4465]).reshape((1, 1, 1, -1)) # channel means\n",
    "CIFAR_STDV = np.array([0.2471, 0.2435, 0.2616]).reshape((1, 1, 1, -1)) # channel standard deviations\n",
    "\n",
    "# Training augmentation pipeline, built once and reused for every batch\n",
    "train_seq = iaa.Sequential([\n",
    "    iaa.Sometimes(0.5, iaa.CropAndPad(px=(-4, 4))), # with probability 0.5, crop/pad by up to 4 pixels\n",
    "    iaa.Fliplr(0.5)])                                # horizontal flip with probability 0.5\n",
    "\n",
    "def to_uint8_nhwc(images):\n",
    "    \"\"\"Convert a float NCHW batch in [0,1] into a uint8 NHWC batch in [0,255].\"\"\"\n",
    "    images = np.clip(np.rint(images * 255), 0, 255) # round: astype alone truncates e.g. 254.99998 down to 254\n",
    "    return images.transpose((0, 2, 3, 1)).astype(np.uint8)\n",
    "\n",
    "def normalize_nchw(images):\n",
    "    \"\"\"Standardize a uint8 NHWC batch per channel and return it as a float32 NCHW batch.\"\"\"\n",
    "    images = (images / 255.0 - CIFAR_MEAN) / CIFAR_STDV\n",
    "    return images.transpose((0, 3, 1, 2)).astype(np.float32)\n",
    "\n",
    "# Training data augmentation\n",
    "def train_augment(images):\n",
    "    \"\"\"\n",
    "    Randomly crop/pad and flip a training batch, then normalize it.\n",
    "    Input:  float NCHW batch with values in [0,1]\n",
    "    Output: float32 NCHW batch, standardized per channel\n",
    "    \"\"\"\n",
    "    images = to_uint8_nhwc(images)      # imgaug is given uint8 NHWC images\n",
    "    images = train_seq(images=images)   # random augmentation\n",
    "    return normalize_nchw(images)\n",
    "\n",
    "# Validation data preprocessing (no augmentation)\n",
    "def valid_augment(images):\n",
    "    \"\"\"\n",
    "    Normalize a validation batch exactly like training, without random changes.\n",
    "    Input:  float NCHW batch with values in [0,1]\n",
    "    Output: float32 NCHW batch, standardized per channel\n",
    "    \"\"\"\n",
    "    return normalize_nchw(to_uint8_nhwc(images))\n",
    "\n",
    "# Read one training batch\n",
    "train_reader = paddle.batch(\n",
    "    paddle.reader.shuffle(paddle.dataset.cifar.train100(), buf_size=50000),\n",
    "    batch_size=128) # shuffled batch reader\n",
    "train_data = next(train_reader()) # one training batch\n",
    "\n",
    "train_image = np.array([x[0] for x in train_data]).reshape((-1, 3, 32, 32)).astype(np.float32) # training images\n",
    "train_image = train_augment(train_image)                                                       # augment + normalize\n",
    "train_label = np.array([x[1] for x in train_data]).reshape((-1, 1)).astype(np.int64)           # training labels\n",
    "print('train_data: image shape {}, label shape:{}'.format(train_image.shape, train_label.shape))\n",
    "\n",
    "# Read one validation batch\n",
    "valid_reader = paddle.batch(\n",
    "    paddle.dataset.cifar.test100(),\n",
    "    batch_size=128) # batch reader\n",
    "valid_data = next(valid_reader()) # one validation batch\n",
    "\n",
    "valid_image = np.array([x[0] for x in valid_data]).reshape((-1, 3, 32, 32)).astype(np.float32) # validation images\n",
    "valid_image = valid_augment(valid_image)                                                       # normalize\n",
    "valid_label = np.array([x[1] for x in valid_data]).reshape((-1, 1)).astype(np.int64)           # validation labels\n",
    "print('valid_data: image shape {}, label shape:{}'.format(valid_image.shape, valid_label.shape))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": false
   },
   "source": [
    "### 模型设计"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "import paddle.fluid as fluid\n",
    "from paddle.fluid.dygraph.nn import Conv2D, Pool2D, Linear, BatchNorm\n",
    "import math\n",
    "\n",
    "# Backbone layout, one tuple per group:\n",
    "#   (input channels, bottleneck channels, stride, number of basic blocks, queue length)\n",
    "# A group outputs 4x its bottleneck channels, which is the next group's input width.\n",
    "group_arch = [(3, 128, 1, 2, 3), (512, 256, 2, 2, 3), (1024, 512, 2, 2, 3)]\n",
    "group_dim  = group_arch[-1][1] * 4 # backbone output channels (2048); derived so it cannot drift from group_arch\n",
    "class_dim  = 100  # number of classes\n",
    "\n",
    "# Convolution unit\n",
    "class ConvUnit(fluid.dygraph.Layer):\n",
    "    def __init__(self, in_dim, out_dim, filter_size=3, stride=1, act=None):\n",
    "        \"\"\"\n",
    "        Convolution followed by batch normalization (and optional activation).\n",
    "        Output size: H/W = (H/W + 2*P - F)/S + 1 with P = (F-1)//2, so an odd\n",
    "        filter at stride 1 keeps the spatial size.\n",
    "        Args:\n",
    "            in_dim      - number of input channels\n",
    "            out_dim     - number of output channels\n",
    "            filter_size - convolution kernel size\n",
    "            stride      - convolution stride\n",
    "            act         - activation applied by the batch-norm layer (None: linear)\n",
    "        \"\"\"\n",
    "        super(ConvUnit, self).__init__()\n",
    "\n",
    "        same_padding = (filter_size - 1) // 2\n",
    "\n",
    "        # Convolution: MSRA normal initialization, no bias (batch norm supplies the shift)\n",
    "        self.conv = Conv2D(\n",
    "            num_channels=in_dim,\n",
    "            num_filters=out_dim,\n",
    "            filter_size=filter_size,\n",
    "            stride=stride,\n",
    "            padding=same_padding,\n",
    "            param_attr=fluid.initializer.MSRA(uniform=False),\n",
    "            bias_attr=False,\n",
    "            act=None)\n",
    "\n",
    "        # Batch normalization: scale initialized to 1, shift to 0\n",
    "        self.norm = BatchNorm(\n",
    "            num_channels=out_dim,\n",
    "            param_attr=fluid.initializer.Constant(1.0),\n",
    "            bias_attr=fluid.initializer.Constant(0.0),\n",
    "            act=act)\n",
    "\n",
    "    def forward(self, x):\n",
    "        \"\"\"Convolve x, then batch-normalize (and activate) the result.\"\"\"\n",
    "        return self.norm(self.conv(x))\n",
    "\n",
    "# Projection unit (shortcut path)\n",
    "class ProjUnit(fluid.dygraph.Layer):\n",
    "    def __init__(self, in_dim, out_dim, filter_size=1, stride=1, act=None):\n",
    "        \"\"\"\n",
    "        Average pooling, then a 1x1 convolution, then batch normalization.\n",
    "        The pooling window and stride do the downsampling; the 1x1 convolution\n",
    "        only changes the channel count.\n",
    "        Args:\n",
    "            in_dim      - number of input channels\n",
    "            out_dim     - number of output channels\n",
    "            filter_size - average-pooling window size\n",
    "            stride      - average-pooling stride\n",
    "            act         - activation applied by the batch-norm layer (None: linear)\n",
    "        \"\"\"\n",
    "        super(ProjUnit, self).__init__()\n",
    "\n",
    "        # Spatial downsampling, no padding\n",
    "        self.pool = Pool2D(pool_size=filter_size, pool_stride=stride, pool_padding=0, pool_type='avg')\n",
    "\n",
    "        # Channel projection: 1x1 convolution, MSRA normal initialization, no bias\n",
    "        self.conv = Conv2D(num_channels=in_dim, num_filters=out_dim, filter_size=1, stride=1, padding=0,\n",
    "                           param_attr=fluid.initializer.MSRA(uniform=False), bias_attr=False, act=None)\n",
    "\n",
    "        # Batch normalization: scale initialized to 1, shift to 0\n",
    "        self.norm = BatchNorm(num_channels=out_dim,\n",
    "                              param_attr=fluid.initializer.Constant(1.0),\n",
    "                              bias_attr=fluid.initializer.Constant(0.0),\n",
    "                              act=act)\n",
    "\n",
    "    def forward(self, x):\n",
    "        \"\"\"Pool x, project its channels with a 1x1 convolution, then batch-normalize.\"\"\"\n",
    "        pooled = self.pool(x)\n",
    "        projected = self.conv(pooled)\n",
    "        return self.norm(projected)\n",
    "\n",
    "# Queue structure\n",
    "class SSRQueue(fluid.dygraph.Layer):\n",
    "    def __init__(self, in_dim, out_dim, stride=1, queues=2, act=None):\n",
    "        \"\"\"\n",
    "        Chain of `queues` 3x3 conv units whose outputs are split and re-concatenated.\n",
    "        Unit 0 maps in_dim -> out_dim channels and applies `stride`; each later unit\n",
    "        is half as wide as the previous one (integer division). After every unit but\n",
    "        the last, the output is split on the channel axis: the trailing half-width\n",
    "        part feeds the next unit and the leading part is kept. Kept parts plus the\n",
    "        last unit's output are concatenated, giving out_dim channels in total.\n",
    "        Args:\n",
    "            in_dim  - number of input channels\n",
    "            out_dim - number of output channels\n",
    "            stride  - stride of the first unit: 1 keeps H/W, 2 downsamples\n",
    "            queues  - number of conv units in the chain\n",
    "            act     - activation of every conv unit\n",
    "        \"\"\"\n",
    "        super(SSRQueue, self).__init__()\n",
    "\n",
    "        self.queues = queues   # number of units\n",
    "        self.split_list = []   # split_list[i]: channels passed from unit i to unit i+1\n",
    "        self.queue_list = []   # conv units in execution order\n",
    "\n",
    "        width = out_dim\n",
    "        for i in range(queues):\n",
    "            first = (i == 0)\n",
    "            unit = ConvUnit(\n",
    "                in_dim=in_dim if first else width,\n",
    "                out_dim=width,\n",
    "                filter_size=3,\n",
    "                stride=stride if first else 1,\n",
    "                act=act)\n",
    "            self.queue_list.append(self.add_sublayer('queue_' + str(i), unit))\n",
    "\n",
    "            if i < queues - 1:\n",
    "                width //= 2\n",
    "                self.split_list.append(width)\n",
    "\n",
    "    def forward(self, x):\n",
    "        \"\"\"Run the unit chain and concatenate the kept channel groups along axis 1.\"\"\"\n",
    "        outputs = []\n",
    "        last = self.queues - 1\n",
    "        for i, unit in enumerate(self.queue_list):\n",
    "            x = unit(x)\n",
    "            if i == last:\n",
    "                outputs.append(x)\n",
    "            else:\n",
    "                kept, x = fluid.layers.split(input=x, num_or_sections=[-1, self.split_list[i]], dim=1)\n",
    "                outputs.append(kept)\n",
    "        return fluid.layers.concat(input=outputs, axis=1)\n",
    "    \n",
    "# Basic (bottleneck residual) structure\n",
    "class SSRBasic(fluid.dygraph.Layer):\n",
    "    def __init__(self, in_dim, out_dim, stride=1, queues=1, is_pass=True):\n",
    "        \"\"\"\n",
    "        Bottleneck residual block: 1x1 reduce -> 3x3 conv (or SSRQueue) -> 1x1 expand,\n",
    "        added to the shortcut and passed through ReLU. Outputs out_dim*4 channels.\n",
    "        Args:\n",
    "            in_dim  - number of input channels\n",
    "            out_dim - bottleneck width; the block outputs out_dim*4 channels\n",
    "            stride  - stride of the middle stage and of the shortcut projection\n",
    "            queues  - 1: a single 3x3 ConvUnit; >1: an SSRQueue of that length\n",
    "            is_pass - True: identity shortcut (needs in_dim == out_dim*4 and stride 1);\n",
    "                      False: projected shortcut through a ProjUnit\n",
    "        Raises:\n",
    "            ValueError - if is_pass is True but the identity shortcut cannot match the output\n",
    "        \"\"\"\n",
    "        super(SSRBasic, self).__init__()\n",
    "\n",
    "        if is_pass and (in_dim != out_dim*4 or stride != 1):\n",
    "            raise ValueError('identity shortcut needs in_dim == out_dim*4 and stride == 1, '\n",
    "                             'got in_dim={}, out_dim={}, stride={}'.format(in_dim, out_dim, stride))\n",
    "\n",
    "        # Identity-shortcut flag\n",
    "        self.is_pass = is_pass\n",
    "\n",
    "        # Shortcut projection: only built when it is used. An unused ProjUnit would add\n",
    "        # parameters that are counted and saved but never receive gradients.\n",
    "        if is_pass:\n",
    "            self.proj = None\n",
    "        else:\n",
    "            self.proj = ProjUnit(in_dim=in_dim, out_dim=out_dim*4, filter_size=stride, stride=stride, act=None)\n",
    "\n",
    "        # Residual path\n",
    "        self.con1 = ConvUnit(in_dim=in_dim, out_dim=out_dim, filter_size=1, stride=1, act='relu') # reduce channels\n",
    "\n",
    "        if queues == 1:\n",
    "            self.con2 = ConvUnit(in_dim=out_dim, out_dim=out_dim, filter_size=3, stride=stride, act='relu')\n",
    "        else:\n",
    "            self.con2 = SSRQueue(in_dim=out_dim, out_dim=out_dim, stride=stride, queues=queues, act='relu')\n",
    "\n",
    "        self.con3 = ConvUnit(in_dim=out_dim, out_dim=out_dim*4, filter_size=1, stride=1, act=None) # expand channels\n",
    "\n",
    "    def forward(self, x):\n",
    "        \"\"\"\n",
    "        Apply the residual block.\n",
    "        Returns:\n",
    "            x - block output (shortcut + residual, then ReLU)\n",
    "            y - the same tensor as x\n",
    "        \"\"\"\n",
    "        # Shortcut path\n",
    "        x_pass = x if self.is_pass else self.proj(x)\n",
    "\n",
    "        # Residual path\n",
    "        x_con1 = self.con1(x)      # reduce channels\n",
    "        x_con2 = self.con2(x_con1) # extract features\n",
    "        x_con3 = self.con3(x_con2) # expand channels\n",
    "\n",
    "        # Add the two paths and apply ReLU\n",
    "        x = fluid.layers.elementwise_add(x=x_pass, y=x_con3, act='relu')\n",
    "        y = x\n",
    "\n",
    "        return x, y\n",
    "    \n",
    "# Block structure\n",
    "class SSRBlock(fluid.dygraph.Layer):\n",
    "    def __init__(self, in_dim, out_dim, stride=1, basics=1, queues=1):\n",
    "        \"\"\"\n",
    "        Stack of `basics` SSRBasic blocks.\n",
    "        Only the first block takes in_dim channels, applies `stride` and uses a\n",
    "        projected shortcut; the rest take out_dim*4 channels, use stride 1 and an\n",
    "        identity shortcut.\n",
    "        Args:\n",
    "            in_dim  - number of input channels\n",
    "            out_dim - bottleneck width of every SSRBasic (outputs out_dim*4 channels)\n",
    "            stride  - stride of the first SSRBasic\n",
    "            basics  - number of SSRBasic blocks\n",
    "            queues  - queue length passed to every SSRBasic\n",
    "        \"\"\"\n",
    "        super(SSRBlock, self).__init__()\n",
    "\n",
    "        self.block_list = [] # SSRBasic blocks in execution order\n",
    "        for i in range(basics):\n",
    "            head = (i == 0)\n",
    "            basic = SSRBasic(\n",
    "                in_dim=in_dim if head else out_dim*4,\n",
    "                out_dim=out_dim,\n",
    "                stride=stride if head else 1,\n",
    "                queues=queues,\n",
    "                is_pass=not head)\n",
    "            self.block_list.append(self.add_sublayer('block_' + str(i), basic))\n",
    "\n",
    "    def forward(self, x):\n",
    "        \"\"\"\n",
    "        Run the blocks in order.\n",
    "        Returns:\n",
    "            x      - output of the last block\n",
    "            y_list - second output of every block, in order\n",
    "        \"\"\"\n",
    "        y_list = []\n",
    "        for basic in self.block_list:\n",
    "            x, y = basic(x)\n",
    "            y_list.append(y)\n",
    "        return x, y_list\n",
    "\n",
    "# Group structure (backbone)\n",
    "class SSRGroup(fluid.dygraph.Layer):\n",
    "    def __init__(self):\n",
    "        \"\"\"\n",
    "        Backbone built from the module-level `group_arch`: one SSRBlock per\n",
    "        (in_dim, out_dim, stride, basics, queues) entry, applied in order.\n",
    "        \"\"\"\n",
    "        super(SSRGroup, self).__init__()\n",
    "\n",
    "        self.group_list = [] # SSRBlocks in execution order\n",
    "        for i, block_arch in enumerate(group_arch):\n",
    "            in_dim, out_dim, stride, basics, queues = block_arch[:5]\n",
    "            block = SSRBlock(in_dim=in_dim, out_dim=out_dim, stride=stride, basics=basics, queues=queues)\n",
    "            self.group_list.append(self.add_sublayer('group_' + str(i), block))\n",
    "\n",
    "    def forward(self, x):\n",
    "        \"\"\"\n",
    "        Run the groups in order.\n",
    "        Returns:\n",
    "            x      - output of the last group\n",
    "            y_list - per-group lists of block outputs, in order\n",
    "        \"\"\"\n",
    "        y_list = []\n",
    "        for block in self.group_list:\n",
    "            x, block_outputs = block(x)\n",
    "            y_list.append(block_outputs)\n",
    "        return x, y_list\n",
    "        \n",
    "# Classification network\n",
    "class SSRNet(fluid.dygraph.Layer):\n",
    "    def __init__(self):\n",
    "        \"\"\"\n",
    "        Classifier: SSRGroup backbone, global average pooling, then a fully\n",
    "        connected layer with softmax over class_dim classes.\n",
    "        \"\"\"\n",
    "        super(SSRNet, self).__init__()\n",
    "\n",
    "        self.backbone = SSRGroup()                               # output: N x group_dim x H x W\n",
    "        self.pool = Pool2D(global_pooling=True, pool_type='avg') # output: N x group_dim x 1 x 1\n",
    "\n",
    "        bound = 1.0 / math.sqrt(group_dim)                       # uniform init range: +-1/sqrt(fan_in)\n",
    "        self.fc = Linear(                                        # output: N x class_dim\n",
    "            input_dim=group_dim,\n",
    "            output_dim=class_dim,\n",
    "            param_attr=fluid.initializer.Uniform(-bound, bound),\n",
    "            bias_attr=fluid.initializer.Constant(0.0),\n",
    "            act='softmax')\n",
    "\n",
    "    def forward(self, x):\n",
    "        \"\"\"\n",
    "        Classify a batch of images.\n",
    "        Returns:\n",
    "            class probabilities of shape N x class_dim (softmax output)\n",
    "        \"\"\"\n",
    "        features, _ = self.backbone(x) # per-group outputs are not used here\n",
    "        pooled = self.pool(features)\n",
    "        flat = fluid.layers.reshape(pooled, [pooled.shape[0], -1])\n",
    "        return self.fc(flat)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tatol param: 21322980\n",
      "infer shape: [1, 100]\n"
     ]
    }
   ],
   "source": [
    "import paddle.fluid as fluid\n",
    "from paddle.fluid.dygraph.base import to_variable\n",
    "import numpy as np\n",
    "\n",
    "# Smoke test: run one random 3x32x32 input through SSRNet and count its parameters\n",
    "with fluid.dygraph.guard():\n",
    "    # Random input batch (N=1, NCHW)\n",
    "    x = np.random.randn(1, 3, 32, 32).astype(np.float32)\n",
    "    x = to_variable(x)\n",
    "    \n",
    "    # Forward pass\n",
    "    backbone = SSRNet() # network under test\n",
    "    infer = backbone(x) # class probabilities\n",
    "    \n",
    "    # Count every element of every tensor returned by parameters()\n",
    "    parameters = sum(int(np.prod(p.shape)) for p in backbone.parameters())\n",
    "    \n",
    "    print('total param:', parameters)\n",
    "    print('infer shape:', infer.shape)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": false
   },
   "source": [
    "### 训练模型"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAX4AAAD8CAYAAABw1c+bAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4zLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvIxREBQAAIABJREFUeJzt3Xd8VeX9wPHPk9ybTRJGmGEv2QgRQRQITlxUxVaKiNqWVlvrqD9Xq7WOilrrKFpFixW1IFIVAdGKoqAoEjYyQ1gJKwlkrzue3x/PSe7N3iQn+b5fr/u6Z91znnNP8j3PedZVWmuEEEK0HgFNnQAhhBBnlgR+IYRoZSTwCyFEKyOBXwghWhkJ/EII0cpI4BdCiFZGAr8QQrQy1QZ+pVSIUuoHpdRWpdSPSqm/VLBNsFLqPaVUolJqvVKqV2MkVgghRP3VJMdfCEzWWo8ARgKXKaXGltnmF8BprXU/4Hng6YZNphBCiIbiqG4Dbbr25lizTutVtrvvVOBRa3oJMFcppXQV3YI7dOige/XqVdv0CiFEq7Zx48Y0rXVMffZRbeAHUEoFAhuBfsDLWuv1ZTbpBhwB0Fq7lVKZQHsgrbJ99urVi4SEhDolWgghWiul1KH67qNGlbtaa4/WeiQQC4xRSg2ty8GUUrOVUglKqYTU1NS67EIIIUQ91apVj9Y6A1gNXFZmVQrQHUAp5QCigPQKPj9Pax2ntY6LianXk4oQQog6qkmrnhilVLQ1HQpcDOwus9nHwCxrehrwZVXl+0IIIZpOTcr4uwBvWeX8AcBirfVypdRjQILW+mPgX8DbSqlE4BRwQ10S43K5SE5OpqCgoC4fF35CQkKIjY3F6XQ2dVKEEM1MTVr1bAPOrmD5I37TBcD19U1McnIybdq0oVevXiil6ru7VktrTXp6OsnJyfTu3bupkyOEaGaaVc/dgoIC2rdvL0G/npRStG/fXp6chBAValaBH5Cg30DkexRCVKbZBf7qFLg8HM8swOXxNnVShBDClmwZ+E9mF+DxNnyjofT0dEaOHMnIkSPp3Lkz3bp1K5kvKiqq0T5uueUW9uzZU+NjvvHGG9x11111TbIQQtRajXruthbt27dny5YtADz66KNERERw7733ltpGa43WmoCAiu+Zb775ZqOnUwgh6sN2Of7ikusz2UkgMTGRwYMHM2PGDIYMGcKxY8eYPXs2cXFxDBkyhMcee6xk2/PPP58tW7bgdruJjo7mgQceYMSIEYwbN46TJ09WeZwDBw4QHx/P8OHDufjii0lOTgZg0aJFDB06lBEjRhAfHw/A9u3bOeeccxg5ciTDhw8nKSmp8b4AIUSL0mxz/H9Z9iM7j2aVW+7xagpcHkKDAgmoZQXm4K6R/PmqIXVKz+7du1mwYAFxcXEAzJkzh3bt2uF2u4mPj2fatGkMHjy41GcyMzOZOHEic+bM4Z577mH+/Pk88MADlR7j9ttv55e//CUzZsxg3rx53HXXXSxZsoS//OUvfPXVV3Tq1ImMjAwAXnnlFe69915+9rOfUVhYiPSXE0LUlO1y/E2lb9++JUEfYOHChYwaNYpRo0axa9cudu7cWe4zoaGhTJkyBYDRo0dz8ODBKo+xfv16brjB9H276aabWLt2LQDjx4/npptu4o033sDrNZXa5513Hk888QTPPPMMR44cISQkpCFOUwjRCjTbHH9lOfPMfBeH0nPp3zGC0KAzl/zw8PCS6X379vHiiy/yww8/EB0dzY033lhhm/mgoKCS6cDAQNxud52O/frrr7N+/XqWL1/OqFGj2Lx5MzNnzmTcuHGsWLGCyy67jPnz5zNhwoQ67V8I0bpIjr8OsrKyaNOmDZGRkRw7dozPPvusQfY7duxYFi9eDMA777xTEsiTkpIYO3Ysjz/+OG3btiUlJYWkpCT69evHnXfeyZVXXsm2bdsaJA1CiJav2eb4K9MUlbtl
jRo1isGDB3PWWWfRs2dPxo8f3yD7ffnll7n11lt56qmn6NSpU0kLobvvvpsDBw6gteaSSy5h6NChPPHEEyxcuBCn00nXrl159NFHGyQNQoiWTzVVpWBcXJwu+0Msu3btYtCgQVV+LivfxcH0XPp1jCDsDBb12FFNvk8hhL0opTZqreOq37JyUtQjhBCtjAR+IYRoZewb+KXZuhBC1In9Ar9VuytxXwgh6sZ2gV8GGxZCiPqxXeAXQghRPxL4/cTHx5frjPXCCy9w2223Vfm5iIgIAI4ePcq0adMq3GbSpEmUbb5a1XIhhGgsEvj9TJ8+nUWLFpVatmjRIqZPn16jz3ft2pUlS5Y0RtKEEKLB2C7wN2bP3WnTprFixYqSH105ePAgR48e5YILLiAnJ4cLL7yQUaNGMWzYMJYuXVru8wcPHmTo0KEA5Ofnc8MNNzBo0CCuueYa8vPzqz3+woULGTZsGEOHDuX+++8HwOPxcPPNNzN06FCGDRvG888/D8BLL73E4MGDGT58eMnAbkIIURPNt+vrygfg+PZyi0O9Xvq4vIQGBUJtf1e28zCYMqfS1e3atWPMmDGsXLmSqVOnsmjRIn7605+ilCIkJIQPP/yQyMhI0tLSGDt2LFdffXWlv237z3/+k7CwMHbt2sW2bdsYNWpUlUk7evQo999/Pxs3bqRt27ZccsklfPTRR3Tv3p2UlBR27NgBUDIs85w5czhw4ADBwcEly4QQoiZsl+Mv0UjtOf2Le/yLebTWPPTQQwwfPpyLLrqIlJQUTpw4Uel+1qxZw4033gjA8OHDGT58eJXH3bBhA5MmTSImJgaHw8GMGTNYs2YNffr0ISkpiTvuuINPP/2UyMjIkn3OmDGDd955B4ej+d6/hRDNT/ONGJXkzPML3CSl5dCnQzgRIc4GP+zUqVO5++672bRpE3l5eYwePRqAd999l9TUVDZu3IjT6aRXr14VDsXc0Nq2bcvWrVv57LPPePXVV1m8eDHz589nxYoVrFmzhmXLlvHkk0+yfft2uQEIIWrEfjn+Rm7IHxERQXx8PLfeemupSt3MzEw6duyI0+lk9erVHDp0qMr9TJgwgf/85z8A7Nixo9phk8eMGcPXX39NWloaHo+HhQsXMnHiRNLS0vB6vVx33XU88cQTbNq0Ca/Xy5EjR4iPj+fpp58mMzOTnJyc+p+8EKJVsF0W8UwMyzx9+nSuueaaUi18ZsyYwVVXXcWwYcOIi4vjrLPOqnIft912G7fccguDBg1i0KBBJU8OlenSpQtz5swhPj4erTVXXHEFU6dOZevWrdxyyy0lv7z11FNP4fF4uPHGG8nMzERrze9//3uio6Prf+JCiFbBdsMy5xa62Z+aQ+8O4bRphKKelkSGZRai5ZFhmYUQQtSaBH4hhGhlml3gb6qip5ZGvkchRGWaVeAPCQkhPT29RkFLwlrltNakp6cTEhLS1EkRQjRD1bbqUUp1BxYAnTDxdp7W+sUy20wClgIHrEUfaK0fq21iYmNjSU5OJjU1tdJtitxeTmYX4jkVRIgzsLaHaDVCQkKIjY1t6mQIIZqhmjTndAN/0FpvUkq1ATYqpT7XWu8ss91arfWV9UmM0+mkd+/eVW6z5UgGv3r3W+bfHMfkszrV53BCCNEqVVvUo7U+prXeZE1nA7uAbo2dsMqUtOOXsh4hhKiTWpXxK6V6AWcD6ytYPU4ptVUptVIpNaQB0lahAGtQNAn8QghRNzXuuauUigD+C9yltc4qs3oT0FNrnaOUuhz4COhfwT5mA7MBevToUacEFw+G6ZXIL4QQdVKjHL9SyokJ+u9qrT8ou15rnaW1zrGmPwGcSqkOFWw3T2sdp7WOi4mJqVfCJewLIUTdVBv4lRlw/l/ALq313yvZprO1HUqpMdZ+0xsyob5jmXfJ8AshRN3UpKhnPDAT2K6U2mItewjoAaC1fhWYBtymlHID+cANupF6EKkzMkybEEK0XNUGfq31N1QzGLLW
ei4wt6ESVRXJ8QshRP00q567NVES+Js2GUIIYVv2C/xIc04hhKgP+wX+khy/RH4hhKgL2wX+ACnjF0KIerFd4C+uZ5YOXEIIUTe2C/yqkX9sXQghWjr7BX7rXTL8QghRN/YL/MWDtEnlrhBC1In9Ar/1Ljl+IYSoG/sFfmnVI4QQ9WK/wF/cgauJ0yGEEHZlv8BfkuOX0C+EEHVh38DftMkQQgjbsmHgLx6rR0K/EELUhf0Cv/UucV8IIerGfoFfinqEEKJe7Bf4ZVhmIYSoF/sFfhmWWQgh6sV+gd96lxy/EELUje0CP1LGL4QQ9WK7wK+QMRuEEKI+bBf4AyTHL4QQ9WK7wF/cgcvrldAvhBB1Yb/Ab71L2BdCiLqxXeAPKM7xS+QXQog6sV3gV1aKZaweIYSoG9sF/sCSHL8EfiGEqAvbBf7ioh6Pt4kTIoQQNmW/wG+lWHL8QghRN/YL/DIevxBC1IttA78U9QghRN1UG/iVUt2VUquVUjuVUj8qpe6sYBullHpJKZWolNqmlBrVOMn19dyVoh4hhKgbRw22cQN/0FpvUkq1ATYqpT7XWu/022YK0N96nQv803pvcEoplJKiHiGEqKtqc/xa62Na603WdDawC+hWZrOpwAJtfA9EK6W6NHhqLQFK4ZHAL4QQdVKrMn6lVC/gbGB9mVXdgCN+88mUvzk0mEClpOeuEELUUY0Dv1IqAvgvcJfWOqsuB1NKzVZKJSilElJTU+uyCwAuD/iOCYf+UefPCyFEa1ajwK+UcmKC/rta6w8q2CQF6O43H2stK0VrPU9rHae1jouJialLegEYrfYw8uTSOn9eCCFas5q06lHAv4BdWuu/V7LZx8BNVuuesUCm1vpYA6azFLdy4tBFjbV7IYRo0WrSqmc8MBPYrpTaYi17COgBoLV+FfgEuBxIBPKAWxo+qT4u5SDQ627MQwghRItVbeDXWn+Dbxj8yrbRwG8bKlHVceMkAA94PRAQeKYOK4QQLYLteu4CuJTTTLgLmzYhQghhQ7YM/B5lPah4pJxfCCFqy5aB30WQmZDAL4QQtWbLwC85fiGEqDtbBn63lPELIUSd2TLwu5QU9QghRF3ZMvBLUY8QQtSdLQO/r6hHAr8QQtSWLQO/R4p6hBCizuwZ+AOsHL9HKneFEKK2bBn4pahHCCHqzpaB35fjl8AvhBC1ZcvArwMl8AshRF3ZNPAHmwnpwCWEELVmy8BPoLTqEUKIurJl4FfFOX4J/EIIUWv2DPwOKeMXQoi6smnglzJ+IYSoK3sHfsnxCyFErdky8Ac5AnHhkMAvhBB1YMvA7wxUFOGQnrtCCFEHNg38Abhwylg9QghRB7YN/EVS1COEEHViy8Af4gykUEtRjxBC1IUtA394UCBF2oFXmnMKIUSt2TLwhwU7KMKBxyWBXwghasuWgT8iOJAinHhcBU2dFCGEsB1bBv6wIAcuHHhdUsYvhBC1ZcvA3zYsiCLtwFUkOX4hhKgtWwb+Ed2jcCsHufn5TZ0UIYSwHVsG/jYhToKCQ3EVSuAXQojaqjbwK6XmK6VOKqV2VLJ+klIqUym1xXo90vDJLM8bEIRDu87EoYQQokVx1GCbfwNzgQVVbLNWa31lg6SohjzKgUO7z+QhhRCiRag2x6+1XgOcOgNpqRV3gFNy/EIIUQcNVcY/Tim1VSm1Uik1pIH2WSWPcuLQ0pxTCCFqqyEC/yagp9Z6BPAP4KPKNlRKzVZKJSilElJTU+t10ORsL4HeIjYeanYPI0II0azVO/BrrbO01jnW9CeAUynVoZJt52mt47TWcTExMfU6boY3lAhVwLp9J+u1HyGEaG3qHfiVUp2VUsqaHmPtM72++61Olg4HIMid29iHEkKIFqXaVj1KqYXAJKCDUioZ+DPgBNBavwpMA25TSrmBfOAGrbVutBRbsgkFIMid3diHEkKIFqXawK+1nl7N+rmY5p5nlC/Hn3OmDy2EELZmy567AFmEARAsOX4h
hKgV+wZ+bQV+j+T4hRCiNuwb+Ety/BL4hRCiNuwb+K0yfinqEUKI2rFt4M+xWvUEeyTwCyFEbdg28HsIJEeHSKseIYSoJdsGfjDl/O68jKZOhhBC2Iq9A78O51R6/cb8EUKI1sbWgT+bUNqQ19TJEEIIW7F14M/S4UQqGatHCCFqw96BnzAiJccvhBC1YuvAn63DiFY50PhjwgkhRIth68C/XfcmSuVB0uqmTooQQtiGrQP/Us94jul2sOa5pk6KEELYhq0DfxFOFronw6FvIF/a8wshRE3YOvADbNN9zMSJHU2bECGEsAnbBv4OEcEA/Kh7mQXHtjVdYoQQwkZsG/iX/m48kSEOAtp0gohOcFwCvxBC1IRtA3+36FCuGN6VE1mFuDsOhePbmzpJQghhC7YN/ACBVuo/ONYBTu6CgqymTZAQQtiAvQO/UgC8nzEQtKfltedP2QgH1jR1KoQQLYytA39xf91Nuj+ERMPez87MgY9sAHdh4x/n0wfh49/XbNvD38MXjzVueoQQLYKtA7/ba0K/h0DofzHs/RTyTzfuQbOOwb8uhk0Lyq8rzAavt2GO43HBsa1w+iAU1WAguq+egrXPQUFmwxy/OolfNP53LYRoFPYO/B6/IDvm1ybwLp7VcGP3nEoqv6/TBwBtgnKpxBTCC8NN8PW3bxWcOlD7Y5/cBe4Cc6zUPaXX7VkJX83xzeec9BUJpSWW3jZtHyz8Obx6AWSm1OzYafvMuaTurXj98e3wzrVmG+k4J4Tt2Dvwe31B+dYvgPiH4MDXkHGo/jvftwpeOht+/MC3bP9qSLcC64kfS29/fDvkn4If5oG7yCzTGhbPhFWPwp5P4eTumh//6CbfdGqZz339jHm58s38jg9AWzfB9H2+7bKOwYKpcHgdpO+H92eB11P9sTf+23yHe1eaeY/L3GjyTpn5pK/Ne2EWbHm3/Oe1hhX3wuZ3qj+WEOKMs3Xg9/gF/i93n4Q+8WbmyAbIPl7xhwpr+Bu9W/9j3k/8aIpPkhPg7Z/AmmfN8mNb4aPf+nL+KRvNe+5J2L3MTOecAFceJH1lbgD/+xN8/yokrip9rA9+bdYVp68wG1I2QXAUBAbDyZ2+bbNPmJuC9phjn9wNXz8NsedAgAPS/HLpP7xm0jBrGVz+DCRvMDfGyrjy4YfXYfv7Zv7Qd9b7t6YoqTjIH/ga2veDyFg4urn8fk4fhA2vw9Lfwhbre3QX1uymI4RodLYO/P45fgA6DQFHCHxyL7w4Ak4f8gVkgE/+D57qZoJnZbQ2xSb7Pjfzh9bBnB6waIaZzzhsbeeBLe+YQA7mxhDRGYIjTUVr0tdw2AqcBRngKYKDa+Gzh+CbF8yNKT/DFANtWwRbFppjvTQSXr8Qdi6FPhOhwwBT7APw1dPw3ABfWt+bCa+ca3LeV70EbXuXDvz7V0PsGOg8DIZOM2nb9n7l5/710+a7yzlhgvrh70ydRXFwT/ra5P4PrYPeE6HLCDi6xXxnm9+F3Z+Y7Q9/79vnvs9hwxvwREffTTNtnymeAvNUkPiFb/u8U6XrDg5+A3v/55vPzzBPMmVte99cIxmiW4hq2Trwezxl/skDnSYYFWSY8vEXh8Prk02Q2b/aFMNA6eIQMDn60wfhyydM0H/rKhPggqN8wTvH7wnCEeqb3rvS5JSTN0BsHET3NMVB706DZXf5tgtwmDRpj8nNvzkFPvyNLxedlwbv/hRUIKTtMedwwR/Mzay4c9qmt8x7h4EQ3tE8XfS/BH71JXQaDO37wq5l5ri5aeaJoM8k8xlnCAy62qwvzDbLso7C6r9CUZ65Aa2ba24Qs5abYrOCDDi22Rf4D60z00U50Gs8dB1pznX9q7D0dlg03dw8jnxvvrv+l5ob7yf3mc8nfW0C89w4+Ft/SJhvngoWzzJpAVg4HRbd6PveVj0KH//OF9A/udcUX7kLS1dkb34bdi8vXyzmypfOfUKUYevAP/3cHuUXdosz7+EdfctSd5cu
XslM9k0f+s7k6F8cYXKkq580y2//Dkb8zLedCjRBHUxOfOAVMO53Jnf63EBT6dt3MrTtaXK8niITOAEGXQWXPAmBQaACwJVrKo4TPzfl6R2HmO2KsuHyZ2H8XTD2tyawdj3b5MCPbYOsFJj8MPz6a+g2ynxm8sPmZgfQebh53/gmrLwP0L7ADxB3qwnaXz5hAunS35lAvektU7zjdcPFj0HvC0wrqdB2Jigf+s40l3XlwvrXfN9zl5HmGJ8+AL0uMIH+h3nm5tl9jLlpZRwyN7vwGHPOp5J86Vl+N0R2M8dd+juT20/+wYy2mnHYpPHkLnP+Xz5hjn1grbkxfnwHzJtkKqwPfmtuvFC+Se/Hv4fXJlb9lFfWsa3m6aH4qUSIFsbR1Amoj4kDYjircxt2Hzc5WK016txfQ7veZvyej24zgS5tr8l5dh5mcn+ZR3w7SbSKdIb/DLa9B0fWm4DXtpcpZgEYMAV+8grsXmFyn+36wGVPmWaWez+D6B5w7q9N7jt9vynXL9amC/zMquQMiTKVsEtvN/NeN+SmwrT58NHt5ilj4BQYfLXv813PNu/fv2Lee08AZyiMmW0Cb5fhvm3H3wkDL4P3b4Ed/4V2fX03CIDY0XDOL00A7TAA9n9hin++exmCwk2wjupmto3oCDctNU8/BRlwwb3w7Ytmv2EdzDkHhZtto3vAdW+Y7/mtq0wl99jbIbSt79hnz4Rv/u4r1pm13DyxdBpq6hCW3w0LrvZVUv/4IQy5xlw/gLV/MzdfbdUTbF9ipt++xtwIwNxU934K51tPWomrYPti33U+23qSSEuE7+bClKfBEVz6jyrnJLw2wUz3mQRjfkWjO7AWwjtAx0GNfywhsHmOHyDY4TsFt1ebHPeYX5ng+WAyOMNNrvHoFuh5vglapXL866DbaLh2HnQ/1yzrPAyUgg79zXyPsRDWziwHiOpu3oPC4Y4EmPkBDLjUfKZtz9IJjPabHzkdRv7cBMTYc0wFafexJrd85QtwzaumuMpf52Em4G1daIqYuow0y/tdCPEPlvkyIsy5XPiI+dzMD8rv79zfABpW/cUE/WtfNzfC1N3mycRfl+Ew80PzJDH0Ougbbz7bbZQ51/AOcM9uuGMztOlszmPyw/DTt80NpjiQxZxlvkMw9SLOMOgxzuwzZqB5Ehk5w1ck0/Vs8+M6CfPNfIB1Dtqvcrh4Os2vqeuYX5uiuf89bG7Ky+82N7g2XWDrIlMPA7DuJfNUVFGv6D2f+KaTvjLvHnf57bwe2PAvX0unUutq0ZdDa1hyq0mzEGdItTl+pdR84ErgpNZ6aAXrFfAicDmQB9ystd5UdrvGoqxhGwAK3V6cgQH+K03w3vkxuPNNGfzh70zxwOlDkPAvM3/eHWb7bqNNjr84wMeOgRE/h2HTzHzn4RD/Rxh6beUJKg70znBTNFL2RqAUXDMPImJMQHIEm2X9L6p4f0FhlPRRHvlzcARV/6UMm+ZLc1nt+5qnmdMHYfBU84Rw/b9NJfXQCj7TbRT8Zq2ZHvwT2Pc/8z0Vi+xS+twm3Ot3rP6meKv7uSb4gylG6TkeAsv86V36V1Pf0Wko/HSBycl/87xZd8XfTJn+urmQl25uZgUZpiLfXWCuX0EmXPokeApNYD/ygykuumWluWluWgBvXAi/WGUqzsFUsn/3Mlz/pu/p5MgG88Q38HLTOivhTdOD+sKHTRqSE+DSJ0x9xYp7TCbioj+bzxblwbvXm7+hYdfDFc9Z18/PgbXmia0w29z0wDz5lO0XIkQjqklRz7+BuUAFXVUBmAL0t17nAv+03s+IApcvF1jo8hARXOaUOgyAY1vMdOw5EBVrOkDNm+hrPdLjPPNeHNCKA39QGFzzT9++AgJg4n1VJ6g40HcZbvbT7+Ly2wy4pAZn5ueiv5gy7MvmVL9tdZSCfheZljb9LzXLhlxjXtUZdKUpghk8tWbHcgTBjR+Ym69/ncvIn5ffNjQa/rDHVIKHdzBPYK9P
NutG32zew9qbSuB9/zP1KPEPmfeLHzfnBXDZ06Yl0JHvYcL/Qc/zzFOeI9TcAN6fZW4a4TGmLgHgu1fMMftONnUM3ceYp5st78Dyu0xF9WcPWd9fAIREmhsLmKK69P3m+LtXmH0OucYcqyATvFbxXdytpkXSW1f6zrn3RFMkBib4Zx83T05CNLJqA7/Weo1SqlcVm0wFFmitNfC9UipaKdVFa11Bm7uGl13gewwvdFfwiB1hBZyRN5qgHBVrigncRaYY4/gOU5EJ5j3uVlNsU1fRVoVz+36morYhjK/heD01NWK6CZgDLqvd50Ki4MYltftM7wt807OWm6BZXBldln/Q6zYaJj0I+J7oSp5iomJNQB9/p3n5cwSZJ4b9X8K5t5llMQNMP4agcFPPMPAKc/Nddie06Qprnim9jxE3wFlXwoV/Nn8/Ay83TwwDLjUtmDb+27ftga9NX4XioqfeE80TVKe/wZePm2X7/mcqtdt0NfNn32j+/rYvNkWIxY5tM9+B12v6fUR2M+kWooE1ROVuN8CvtpRka9kZCfz5fjn+oooC/6hZprL1kifMvNN69B44xeTw+k72bRsSBVc+X78EBYXDxAdMrrq5io2D274988f1vwnUxKQHKl5eVVEbmLqFiipK4x8yAb3bKFO23nuieRpc8gtTnLTxLTPf4zzT/PWCe3yfLa4wPvc3sPU9sy7npOmoBqZuI+OwVYcCnH+3qazvOc60mFo3F9Dm6ePqueZpc9cy8wTlCDVFkR/ONsViRbmmzweYp7yAAFOZXZBh6k6EqKcz2qpHKTUbmA3Qo0cFTTHrwOvXYafCHH/MgNLBvOd5Jtd33u8a5PgVKlvpKpqHQKdp2QSmaKZdb/PqdxEEtzFPhce2QvdzKt9Hx0FwX5IpBtzwhlnWpqvpc+FX30RAoO/voM8kkwH57y9N4FbK5PR/8rKp2B12nSk6Sttr+lPkpvr2s2up6Zi38n4T+PtfCtHdG/JbEa1QQwT+FMD/LzHWWlaO1noeMA8gLi6uQbpYer3+gb8GQwL0vxj+eMLk6IQAE/TBFBNVFfSLFVfYdhxs3vtdWDroV6T7GLhrW+mexUOvg15W89y8dNO8t31fs83pg6YX9/s3l27GuvY5UzmBt78SAAAbA0lEQVSeecT3RJK6F7KPlu6zIUQVGiLwfwz8Tim1CFOpm3mmyvcBAgJKt+qpEQn6oiF0Hm5eFVVWV6bsDSIixrwHR5Tepm0v37wjxNwcBk81LdE2LTA3Aq/HFAN9+5IpzvzDHl+dgdamqKm4Fdj2JaajXZcRMOUZc4MRrVZNmnMuBCYBHZRSycCfASeA1vpV4BNMU85ETHPOWxorsRWmz2+60NVAY+ELURPBEb6mrg1NKdO3w+MyPaC9btMX4uROX1HQaqveKjwGCopMXcLom80w3mueNRXPV881/TO+ed40rU1OMD2ZZ33s69yXcdh0dht5o7lRaG2eKKK6V/8kI2ypJq16plezXgO/bbAU1VLpdvwy+qNoQeIqyEPdvMLcBLKPm5ZC3UaZBguvTzZ9Eja9DSesjnBte5vWQVE9IPOwuZH0uxDevMKMJTXpQfP0kLbPVC4nfgHT3jStnNY8ayrC+11kWlM5w0yHt57jy/cl8bh8HQULskzRmdwwmjVbD9kA4FfSU/OiHiHsKiDQvNr2LN058JLHzaB/pw/AFX83Pbw7DzVDbCS8aQLxsOvNU8rMD2H+pWbAuw4DTPPSiE7mCeKVsXBqv2ldlLjKDHz39TOmUnzXMhhyrWkdtX+12f+pJDPq7aQHzVAmS24xTyZXzzXFU0U55sZTtsOeP61NP5X0RF+Hv+ju5oYCZtwrlKlbKa4j8b+xaG1eXpdpERXWzjSJ1V6znafIpMWf12N6XRcXtVVHa98x/acLMk0P+Jre6Lxek86yQ4WcYUo30TC2cXFxOiEhod77OefJVaRmm9+/ff5nI7jm7Nh671MIW3IX+gJfdY5vN0NpT3rA
dJ4D81sMq/4C586GSQ+ZzmrJP5gf1Tmx3dxMijtD+guOgkJrpNSYQaaYyOMyvajBdLyL7mGG4kjdY0ZMjehkOs2dPmiW5foNiBfgNKPNpu0zQVt7AasVVv5p0wfCEWSWtelsWkIpZdJbmGWazBZ3zgwMMvUh7fqY4wQGmeO68s2YUtE9TDD2FJmm2EW5Zl1wG1+fnIIMc1PqMtI8bZ340dyYnOHmewmJNjfh8I7m84EO0zosqofZjzvf9OouyjXnecEfKm+qXANKqY1a67g674AWEPjHPLmKk1bgn3PtMG4Y0zDNRIVolfxzs8U8bpMj7z7GDEdxfIcZljt9vyni6THOPBVExcKomSborn3O/KZDm05mPKyso2Z4lOgeZliTnJPW7z50NcN09BxnmloX5Zob0qn9ZrvgNuYYXrc5XlCEaZzhKTI3l7x0E9Q9LpOTjupuhtEI7wAoE8Q9heazXc82+8lLN+fZvo8Zx8sRYnqMF+Wa4O8MMzn5zCPmZhIUbvZ7fLvJqXccZEbKLcg0w7rkpppj5p406XPlmRtgznGTLmeo9Qo3HQL7TjYj/NZRQwT+FlDU4/sjLXB5cHnKjNcjhKi5ioosAh0mMIMJzj2tIU46DfFtc9lffdPBbeCqF33zo24y7+5Ck5sPqOb/c0oDDE0iqmT7CBkR4rt3PbpsJ5e9UMGIi0KIpucIrj7oizPC9ldh/qxz+P2F/Uvm96fmNmFqhBCi+bN94O/RPoy7L+pf/YZCCCGAFhD4oXRbfoCDaZLrF0KIyrSIwF/WpL991dRJEEKIZqtFBn4hhBCVa7GB33/UTiGEED4tJvDHDyzd9brII8M3CCFERVpM4B/cNbLUvEsCvxBCVKjFBP67LxrAbZN8Y4y7PFLUI4QQFWkxgd8RGMCY3r7BqSTHL4QQFWsxgR8g2G+MnrwiDyezCpowNUII0TzZfpA2f8FOX+CPt9ry73rsMkKDApsoRUII0fy0rBy/o3yAL3DJr3IJIYS/Fhb4y5+O/CqXEEKU1qICf1AFgX9xwpEmSIkQQjRfLSrwF/8oS2xb3+9r/v3zvSSezGmqJAkhRLPTogJ/8SCdQWV+getoRn4TpEYIIZqnFhX4u0WH8puJfXl9VumfozxyOq+JUiSEEM1Pi2rOqZTigSlnlVu+encqvdqH8/raJDq1CeHpacObIHVCCNE8tKgcv7/OkSEl06t2nWDGG+v5ak8q7yUc4cgp8wSwaucJ9p7IBmDviWx6PbCCpNTK6wMKXB7uW7KVE9IxTAhhYy028H9294RK1329NxWAXy5I4JLnzY+z/3dTMgArdxyv9HMrdxxjcUIyc1burlEa0nMKueaVb6WOQQjRrLTYwB8ZYkqxokKd5dZl5rvILXSXWuaxBnVzBKhy2xfLKzKdwcpWHpeVW+gm+XQeSzYms/lwBm9+e6BWaa9MdoGL7cmZuD1eHlu2k5PZ8uQhhKi9Fhv4lVJ8c3883z04mZV3XlBq3bOf7eGcJ1eVWua2frjFUUVQL3CZzmDVDQFx0/wfOP/p1WQVuADTzPT7pPR6/xbw7AUbuWruN8xbm8T8bw/w56U/lqz7YFMyh9MbpxLb7fFy7/tbpVmsEC1Eiw38ALFtwwgLctDNr11/seLcezG31wT1x5fvZHKZ3+wtcnvZnpzJ48t3AuaJYUdKZqXH3XjoNAAvr94PwGtrkrhh3vclvwWstebJFTvZeOgUAGc/9j9+++6mas9nw0Gz/TOf7rHSbG5WxzMLuGfxVm6Y9121+yjr0x3HOe+pLyiqoofz7uPZLNmYzB0LN9d6/0KI5qdFB/5ibYKrbrzU64EVfL7zRMl8kpUz93o1Gw+dZsCfVnLV3G9K1n+4OYUr//FNuf3sOpZFanZhtek5mV3I62sPcPP8DQCcznOxYvsxtNZ8tz8drUv/lkCBy4PWuiTQF0vNLuSplbu44qW1ABzNrH3Rz58/3sHRzALScipPd3HHOBnqWoiWoUU156yM
UpWX2xc7kVU68F36/Br2WC1+KrN2XyohzkB6dwhnw4FT3PbuppK6harsPJYFUC6Q937wEwD+OWMUU4Z1ASAjr4iRj33OQ5eXb6a65UgGW45kVHgMr1ez50Q2AUqx6fBpxvRux5q9qQzqEsnYPu1LtnMEmHu/f44/r8hNWJDvPArc5unIP/Brrflqbyrn9W1f4eB4tZV8Oo+Ptx5lYKc27DyaxR0X9q9y245tQiocokMIUb0aBX6l1GXAi0Ag8IbWek6Z9TcDzwIp1qK5Wus3GjCd9fbHywcR7AzgEb9y8apUF/QBZv7rh3LLsgrcFWxZWuIJU1ae7/LQ64EV5dZ/vvMEr69NYuexLEZ2jwZgcUJytft1Bio8Xk1ggGLRhiM89OH2knURwQ5yrArtg3OuKFnuCDQ3xeJ1//vxOLPf3sjyO85naLcok06rWOzwqTzcHi97T+SQcOgUjyz9kcenDmHmuF4l+yt0e0g5nU+fmIhq0+tv9oKNJTdEoCTwr9ufRv+ObYhpEwyYm9L5T6/mulGxPPfTEbU6hhDCqDbLpJQKBF4GpgCDgelKqcEVbPqe1nqk9WpWQR/gVxP6cNO4XnS0AkhFHr6yotNqWC6Pl4UbDle5zQebU9h0OIMCl5fvk0y5fk0qVl0ezS3/3sDEZ1eXCvrgC+xgirYWbzCD1wVarZi2p2RS5Payeo9p6ro12fckUdwCSmt45rM9XP7S2pIb6MNlbqQD//Qpk5/7mmyrYhvgRFYBN8z7rlRx0p7j2WTkFQGw6fBpDqWXrviet2Y/Wmt+/vp6rn91XUn6T+eZ/S7fdrTa76OxVVUvIkRzVpNn5TFAotY6SWtdBCwCpjZushrP4l+P47GpQ/jm/nguGdyJq0d0LVl3bu92vH6TGe6hJkU2dRH/t69ISq1f656qrNmbyqEatO6577/bOJlVUJKWBz/YzoA/rSwJ8n/8cAfvfH+IvSeymf32xpLPzVuTVG5fR07lmZz5Hz8pWZaeU8SuY1lc+vwaXli1l++TTvHWuoN8sv0YCQdPcekLa7j7vS14vJprX1lHbpnK9r9+srtk2cH0PIb++TNufGM94+d8CfiG2/5gUzI7j5onhdO5Rdy3ZCvvrj9U4+/rWGY+Hm/tf5/5m31pDPjTSrZWUtQmRHNWk+jWDfAf2zgZOLeC7a5TSk0A9gJ3a62b5XjIvTqE06tDOADzrCD/yfZjuL2aAZ3aMLRbFKvumUCPduEM+NPKUp/d8sjFjHzscwBG92zLxkOniQ5zkpHnoqyrRnRl2dbyudLk06YzV6gzkPwm/pGYMX/9otyyj/3S/KePdnBnFWXtxS54ZnWF+/locwpJabklxWb/+DKx1Dar96SSlV/+uyv2y7c2lJr/JjGt1HyBy8M9i7cC8MHt53HtK+sAUyw249yeJdvlFbkZ/MhnPHXtMKaP6cGHm5MZ0jWK6FAn4576kt/F9+PeSweW2neh20OgUpU27y1Oy9p9qYywiuOq8pu3NzK0WyS/m1z99ylEY2uo2rFlQC+t9XDgc+CtijZSSs1WSiUopRJSU1Mb6ND1t+AXY3ht5uiSysJ+HdsQ5Ahg6yOXsPXPlwBw4VkdiQ4LKvnMazNHs/OxS9n88MV88YeJ5fb5m4l9ODjnCpbfcT5zf352ufXfP3Rhten6eyVl2D3bh9XovBrCi1/sq9Pn/v753pLWUVUZ//SXla4rLuaqzFkPf1oyXRz0i+1IyWTplhR+9tp3JTecBz/YziNLd3D3e1u56h/fcMgaumP1npOAKbopcHlISs1h4J8+5frXvkNrzcx/reeLXSdK7T/M6stR/FTi9nh5csVOjmVW3Ev70x+P87f/7a3yfIQ4U2qS408BuvvNx+KrxAVAa53uN/sG8ExFO9JazwPmAcTFxdX++bqRnNe3Q4XLo8JMr98v/jCxZIz/B6ecRd+YCDpE+OoK+sZEsPLOC3jt6/3MuW44mw9nMKSrqRgd2i2Kod2iuGRwZ3Ye
y+InL39r9h3qZPPDF3PlP74hJSOftffFU+TxsvlwBj8cSGdxQjJt/W40xYorZk/nFjH77QRuPq83b357gMd/MpQpL5pmnVeP6Foq5/7WrWOYNd9URC+/43w2Hz5drmy+OusemMx5c3xBuqGeWMr2p2go/s1t1x/w3UAWfGeKgQrdXq5/1fR7CHGaIH7pC2tIycgvKbvffDiDU7lFrN2XxvdJ6ex78nIW/XCYXh3CSwJ/8RPcDwdO8fraAxxIy+ONMqPD7j7uq7Q+mpFP1+jy/UqEOJNU2Tbj5TZQyoEpvrkQE/A3AD/XWv/ot00XrfUxa/oa4H6t9diq9hsXF6cTEhLqmXx7KXB5uPyltcwa14tZ5/UCwOPVHEjLpV9HXyuY3EI33yelM3FADIsTknnlq0ScgQH8Nr4f00bHVrr/07lFfPbjca4dFUuQI4DsAheZ+S66RYfS+8FPaBvmZPMjl/DKV4klncBuGteT31/Yn0K3lz8s3lIulz2iezQzx/Zk2ujYkhZIlw3pzN9+OoLjmQXct2Qr5/Rqx2sVlP3bRfvwIH56Tnf++dX+Kre7fnQs7280ravierYl4dBp+saE869Z53DrWxtK6kve+cW57DmRzbrENEZ0j+bvn5fO6a+88wKWbT3K7fH9CHEElCpOKnB5Sm5EQlREKbVRax1X/ZZV7KO6wG8d6HLgBUxzzvla6yeVUo8BCVrrj5VSTwFXA27gFHCb1rrKkcxaY+BvSusS0+gdE06XqFAy81y88nUi8QM7lmrTD6bV0dGMfLpEhbJ0SwpTR3YrKQK78Lmv6BQZwju/OJeAMmMauTymmORgWh5/+mg7W5MzCXYE8NL0swlyBJCaVch9/90GwG/j++L2av7vkoEMeuRTXB5N7w7hHLCKhs7v16GkDL1TZDD9OkbwbWI6VQkKDGBc3/Y8fd1wxj5Vvu6iMQU7Aur8285XDO/Cry7ow7eJaaRk5POf9Yd5dtpwnIEBjO7Zli5RIVUOI1KW3DhavjMW+BuDBP6WK7fQzcZDp5kwIKbabVMy8glUik6Rwaw/cIpu0aFEBDuYuzqRyWd15Ly+7VFKkZnv4q11B5kwIIYBnSJYteskc7/cx4BObXjq2mG0CTHFclprHvpwB1eN6EJkiJMr//EN3aJDScnIJ9gRwKp7JrJi+7GSEVb7d4xgXzMfg6hHuzCUgt4dwil0edl9PIvRPdsRGerg1xP6suWIKbrrGhXCwfQ8Hp86hLScIk5kFXDj2J4Uuj3069iGx5bt5GfndGdQlzbkFLrJzHcR2zaMiGAHLo+X75PSiQh2cFbnSD7aksJPRnYrNy7V3hPZdG8bRmhQIOv2p7HzaBY3jOlBhNU73uvVFLg9JR0AtdbkFXk4lVvE9pRMLh/WhSOn8jhyKo/z+vmKWL/Zl0aPdmG0jwgixBlIYIBCa12q82VaTiEvrtrH9DE9GNw1stbf4+ncIgICFBHBjpJmzMVpLrIyPH1iIth5NIt24UGEBwdyMC0PjWZ4bDRer+Y/PxxmytDORIU6+XznCSYN7MihU7mc1dmXnuKOjs5a3LBrQwK/ENXIKXQTEewgt9BNgFIlgWzviWzahwfRNiyI75PSGRYbxXf70/F4NRcO6kTiyRw6RARxMruQ3cezuff9rcS2DeX/Lh1IZKiTDzelMHtCH75PSueJFbsAuHRIJ569fgTLth4lr9DDx1uPkngyp0Z1IZEhjhp1/msMzkCFy1NxHAgKDAAFzgCFUqqkP0VUqJNMvxZZIc4AIoId5BV5yHd5aBcWhFKq3FAg/h0JgwIDaBcehNury23XOTKEU7lFRIU5cXm8hDgCOe73Oxhdo0JQSpUMZeLyeHF5NCHOAEBR6PYQFBiAMzAAt9eL26tLWt9FhTqJCHZQ6Pbi8XrJd3nw6qr7ZXSLDiUtp7DSJ7uwoEAc1neUV2T+1qLDnGhtnsK8GsKDA3EEBFDg8jBzXE/uumhApcerigR+IWzM69VkFbhK
cqAnsgrpHBVCRl4RYUEOTmYXUODyorUmMtTJj0fNwIA92oVzKreIbxPT6BIVQniwg42HTpcEtPYRQRxMy2XZtmNM6N+BfJeH7m3DSM0x+3cEKI5nFhIeHEhYkIOM/CLCnA62p2QQGeIkKszJ5sMZnN0jmmBHIBqNx2MC7M6jWXRvF0ZEcCC5RR5Sswvp0S4MR6Ci0O3Fa/UcB9CYXt9pOYV0igwh3+UhMsTJ4VO5HEzLI/6sGIrcXhSKxNQcOkUGk5SaS1yvthS4vGTkFdEhIhhnoAmWTkcAa/el4ggI4OwepgltgFI4AlRJgC9yaxwBiiKPlyK31wTbwACcAYojp/Pp0yGc7AI3Lq+XYEcgeUVu0nIKiQxxkppdyJCukSzccIR+MRG0Cw/itNXJcFCXSLxezapdJ7hoUCeUUizbdpRpo2NJPJFT6gnkVG4RHq8mItiBUmZolmBHANkFbpyBAQQ7A4gf2JGLB3eq09+NBH4hhGhlGiLwyyhXQgjRykjgF0KIVkYCvxBCtDIS+IUQopWRwC+EEK2MBH4hhGhlJPALIUQrI4FfCCFamSbrwKWUSgVq/lNJpXUA0qrdyj7kfJo3OZ/mrbWdT0+tdfUDYVWhyQJ/fSilEurbc605kfNp3uR8mjc5n9qToh4hhGhlJPALIUQrY9fAP6+pE9DA5HyaNzmf5k3Op5ZsWcYvhBCi7uya4xdCCFFHtgv8SqnLlFJ7lFKJSqkHmjo9xZRS3ZVSq5VSO5VSPyql7rSWt1NKfa6U2me9t7WWK6XUS9Z5bFNKjfLb1yxr+31KqVl+y0crpbZbn3lJ+f8uXeOdV6BSarNSark131sptd5Kw3tKqSBrebA1n2it7+W3jwet5XuUUpf6LT+j11IpFa2UWqKU2q2U2qWUGmfn66OUutv6W9uhlFqolAqx0/VRSs1XSp1USu3wW9bo16OyYzTS+Txr/b1tU0p9qJSK9ltXq++9Lte2Ulpr27wwP/a+H+gDBAFbgcFNnS4rbV2AUdZ0G2AvMBh4BnjAWv4A8LQ1fTmwElDAWGC9tbwdkGS9t7Wm21rrfrC2VdZnp5yB87oH+A+w3JpfDNxgTb8K3GZN3w68ak3fALxnTQ+2rlMw0Nu6foFNcS2Bt4BfWtNBQLRdrw/QDTgAhPpdl5vtdH2ACcAoYIffska/HpUdo5HO5xLAYU0/7Xc+tf7ea3ttq0xrY/6jNcIf+zjgM7/5B4EHmzpdlaR1KXAxsAfoYi3rAuyxpl8Dpvttv8daPx14zW/5a9ayLsBuv+Wltmukc4gFvgAmA8utf6A0vz/kkusBfAaMs6Yd1naq7DUq3u5MX0sgChMoVZnltrw+mMB/BBPwHNb1udRu1wfoRelA2ejXo7JjNMb5lFl3DfBuRd9ndd97Xf73qkqn3Yp6iv/YiyVby5oV61HrbGA90ElrfcxadRwo/qHNys6lquXJFSxvTC8A9wHFvzDdHsjQWhf/Krh/GkrSba3PtLav7Xk2lt5AKvCmMkVXbyilwrHp9dFapwB/Aw4DxzDf90bse32KnYnrUdkxGtutmCcPqP351OV/r1J2C/zNnlIqAvgvcJfWOst/nTa3ZFs0o1JKXQmc1FpvbOq0NBAH5jH8n1rrs4FczGN+CZtdn7bAVMwNrSsQDlzWpIlqYGfiepypa66U+iPgBt5t7GPVhN0CfwrQ3W8+1lrWLCilnJig/67W+gNr8QmlVBdrfRfgpLW8snOpanlsBcsby3jgaqXUQWARprjnRSBaKeWoIA0l6bbWRwHp1P48G0sykKy1Xm/NL8HcCOx6fS4CDmitU7XWLuADzDWz6/UpdiauR2XHaBRKqZuBK4EZ1o0Gan8+6dT+2lauscogG+OFybUlYXI5xRUfQ5o6XVbaFLAAeKHM8mcpXZH0jDV9BaUrq36wlrfDlEW3tV4HgHbWurKVVZefoXObhK9y931KVzDdbk3/ltIVTIut6SGUrsRKwlRgnfFrCawFBlrT
j1rXxpbXBzgX+BEIs473FnCH3a4P5cv4G/16VHaMRjqfy4CdQEyZ7Wr9vdf22laZzsb8R2ukP/jLMS1m9gN/bOr0+KXrfMwj4zZgi/W6HFPW9gWwD1jl90epgJet89gOxPnt61Yg0Xrd4rc8DthhfWYu1VTgNOC5TcIX+PtY/1CJ1h9isLU8xJpPtNb38fv8H60078GvpcuZvpbASCDBukYfWYHCttcH+Auw2zrm21YQsc31ARZi6idcmCeyX5yJ61HZMRrpfBIx5e/FMeHVun7vdbm2lb2k564QQrQydivjF0IIUU8S+IUQopWRwC+EEK2MBH4hhGhlJPALIUQrI4FfCCFaGQn8QgjRykjgF0KIVub/AUKTAjDzA94gAAAAAElFTkSuQmCC\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "complete - train time: 25564s, best epoch: 235, best loss: 0.907677, best accuracy: 80.62%\r"
     ]
    },
    {
     "data": {
      "text/plain": [
       "<Figure size 432x288 with 0 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "import paddle\n",
    "import paddle.fluid as fluid\n",
    "from paddle.utils.plot import Ploter\n",
    "import numpy as np\n",
    "import time\n",
    "import math\n",
    "import os\n",
    "\n",
    "epoch_num = 300   # number of training epochs, typically in [1, 300]\n",
    "train_batch = 128 # training batch size, typically in [1, 256]\n",
    "valid_batch = 128 # validation batch size, typically in [1, 256]\n",
    "displays = 100    # log and plot every `displays` iterations\n",
    "\n",
    "start_lr = 0.00001                         # learning rate at the start of warmup, typically in [1e-8, 5e-1]\n",
    "based_lr = 0.1                             # base learning rate reached after warmup, typically in [1e-8, 5e-1]\n",
    "epoch_iters = math.ceil(50000/train_batch) # iterations per epoch (50000 training images)\n",
    "warmup_iter = 10 * epoch_iters             # warmup length in iterations (epoch multiplier typically in [1, 10])\n",
    "\n",
    "momentum = 0.9     # optimizer momentum\n",
    "l2_decay = 0.00005 # L2 regularization coefficient, typically in [1e-5, 5e-4]\n",
    "epsilon = 0.05     # label smoothing rate, typically in [1e-2, 1e-1]\n",
    "\n",
    "checkpoint = True                    # resume from model_path if a checkpoint exists there\n",
    "model_path = './work/out/ssrnet'     # checkpoint path prefix\n",
    "result_txt = './work/out/result.txt' # training log file\n",
    "class_num  = 100                     # number of classes\n",
    "\n",
    "# One format per log line, shared by the console output and the log file\n",
    "log_format  = 'iteration: {:6d}, epoch: {:3d}, train loss: {:.6f}, valid loss: {:.6f}, valid accuracy: {:.2%}'\n",
    "done_format = 'complete - train time: {:.0f}s, best epoch: {:3d}, best loss: {:.6f}, best accuracy: {:.2%}'\n",
    "\n",
    "# The log file is appended to from the first logged iterations on, before any\n",
    "# checkpoint has been saved, so its directory has to exist up front.\n",
    "os.makedirs(os.path.dirname(result_txt), exist_ok=True)\n",
    "\n",
    "with fluid.dygraph.guard():\n",
    "    # Data readers (training samples are shuffled through a 50000-sample buffer)\n",
    "    train_reader = paddle.batch(\n",
    "        reader=paddle.reader.shuffle(reader=paddle.dataset.cifar.train100(), buf_size=50000),\n",
    "        batch_size=train_batch)\n",
    "    \n",
    "    valid_reader = paddle.batch(\n",
    "        reader=paddle.dataset.cifar.test100(),\n",
    "        batch_size=valid_batch)\n",
    "    \n",
    "    # Model (SSRNet is defined elsewhere in this notebook)\n",
    "    model = SSRNet()\n",
    "    \n",
    "    # Optimizer: cosine decay over epoch_num epochs, preceded by a linear warmup\n",
    "    consine_lr = fluid.layers.cosine_decay(based_lr, epoch_iters, epoch_num) # cosine decay schedule\n",
    "    decayed_lr = fluid.layers.linear_lr_warmup(consine_lr, warmup_iter, start_lr, based_lr) # linear warmup from start_lr to based_lr\n",
    "    \n",
    "    optimizer = fluid.optimizer.Momentum(\n",
    "        learning_rate=decayed_lr,                           # scheduled learning rate\n",
    "        momentum=momentum,                                  # momentum coefficient\n",
    "        regularization=fluid.regularizer.L2Decay(l2_decay), # L2 weight decay\n",
    "        parameter_list=model.parameters())\n",
    "    \n",
    "    # Resume only when a checkpoint was actually saved: load_dygraph raises if the\n",
    "    # weights file is missing, so a fresh workspace with checkpoint=True starts from scratch.\n",
    "    if checkpoint and os.path.exists(model_path + '.pdparams'):\n",
    "        model_dict, optimizer_dict = fluid.load_dygraph(model_path) # saved weights / optimizer state\n",
    "        model.set_dict(model_dict)                                  # restore weights\n",
    "        if optimizer_dict is not None:                              # the optimizer file may be absent\n",
    "            optimizer.set_dict(optimizer_dict)                      # restore optimizer state\n",
    "        # NOTE(review): epoch_id, iterator and best_* restart from their initial values on\n",
    "        # resume, so a resumed run trains another epoch_num epochs - TODO persist them if needed.\n",
    "    elif os.path.exists(result_txt): # fresh run: discard the previous log\n",
    "        os.remove(result_txt)\n",
    "    \n",
    "    # Training state\n",
    "    avg_train_loss = 0 # most recently logged training loss\n",
    "    avg_valid_loss = 0 # validation loss of the last finished epoch\n",
    "    avg_valid_accu = 0 # validation accuracy of the last finished epoch\n",
    "    \n",
    "    iterator = 1                                # global iteration counter\n",
    "    train_prompt = \"Train loss\"                 # plot series name for training loss\n",
    "    valid_prompt = \"Valid loss\"                 # plot series name for validation loss\n",
    "    ploter = Ploter(train_prompt, valid_prompt) # live loss plot\n",
    "    \n",
    "    best_epoch = 0           # epoch with the lowest validation loss so far\n",
    "    best_accu = 0            # validation accuracy at best_epoch\n",
    "    best_loss = 100.0        # lowest validation loss so far\n",
    "    train_time = time.time() # training start time\n",
    "    \n",
    "    # Main loop\n",
    "    for epoch_id in range(epoch_num):\n",
    "        # ---- training ----\n",
    "        model.train()\n",
    "        for batch_id, train_data in enumerate(train_reader()):\n",
    "            # Images: flat samples -> NCHW float32, then training augmentation\n",
    "            image_data = np.array([x[0] for x in train_data]).reshape((-1, 3, 32, 32)).astype(np.float32)\n",
    "            image_data = train_augment(image_data)\n",
    "            image = fluid.dygraph.to_variable(image_data)\n",
    "\n",
    "            # Labels: int64 -> one-hot -> label smoothing; targets need no gradient\n",
    "            label_data = np.array([x[1] for x in train_data]).astype(np.int64)\n",
    "            label = fluid.dygraph.to_variable(label_data)\n",
    "            label = fluid.layers.label_smooth(label=fluid.one_hot(label, class_num), epsilon=epsilon)\n",
    "            label.stop_gradient = True\n",
    "\n",
    "            # Forward pass\n",
    "            infer = model(image)\n",
    "            \n",
    "            # Soft-label cross-entropy against the smoothed targets\n",
    "            loss = fluid.layers.cross_entropy(infer, label, soft_label=True)\n",
    "            train_loss = fluid.layers.mean(loss)\n",
    "            \n",
    "            # Backward pass and parameter update\n",
    "            train_loss.backward()\n",
    "            optimizer.minimize(train_loss)\n",
    "            model.clear_gradients()\n",
    "            \n",
    "            # Periodic logging to the plot, the console and the log file\n",
    "            if iterator % displays == 0:\n",
    "                avg_train_loss = train_loss.numpy()[0]\n",
    "                ploter.append(train_prompt, iterator, avg_train_loss)\n",
    "                ploter.plot()\n",
    "                \n",
    "                message = log_format.format(\n",
    "                    iterator, epoch_id+1, avg_train_loss, avg_valid_loss, avg_valid_accu)\n",
    "                print(message)\n",
    "                with open(result_txt, 'a') as file:\n",
    "                    file.write(message + '\\n')\n",
    "            \n",
    "            iterator += 1\n",
    "            \n",
    "        # ---- validation ----\n",
    "        valid_loss_list = [] # per-batch validation losses\n",
    "        valid_accu_list = [] # per-batch validation accuracies\n",
    "        valid_size_list = [] # per-batch sample counts (averaging weights)\n",
    "        \n",
    "        model.eval()\n",
    "        # Evaluation never calls backward(), so don't record the autograd graph\n",
    "        with fluid.dygraph.no_grad():\n",
    "            for batch_id, valid_data in enumerate(valid_reader()):\n",
    "                image_data = np.array([x[0] for x in valid_data]).reshape((-1, 3, 32, 32)).astype(np.float32)\n",
    "                image_data = valid_augment(image_data)\n",
    "                image = fluid.dygraph.to_variable(image_data)\n",
    "                \n",
    "                label_data = np.array([x[1] for x in valid_data]).reshape((-1, 1)).astype(np.int64)\n",
    "                label = fluid.dygraph.to_variable(label_data)\n",
    "                label.stop_gradient = True\n",
    "                \n",
    "                infer = model(image)\n",
    "                \n",
    "                valid_accu = fluid.layers.accuracy(infer, label)\n",
    "                valid_accu_list.append(valid_accu.numpy())\n",
    "                \n",
    "                loss = fluid.layers.cross_entropy(infer, label)\n",
    "                valid_loss = fluid.layers.mean(loss)\n",
    "                valid_loss_list.append(valid_loss.numpy())\n",
    "                \n",
    "                valid_size_list.append(len(valid_data))\n",
    "        \n",
    "        # Weight every batch by its size: a plain mean over batches over-weights a\n",
    "        # smaller final batch (the test set size need not be a multiple of valid_batch).\n",
    "        avg_valid_accu = np.average(np.ravel(valid_accu_list), weights=valid_size_list)\n",
    "        avg_valid_loss = np.average(np.ravel(valid_loss_list), weights=valid_size_list)\n",
    "        ploter.append(valid_prompt, iterator, avg_valid_loss)\n",
    "        \n",
    "        # Latest checkpoint: weights and optimizer state share the model_path prefix\n",
    "        fluid.save_dygraph(model.state_dict(), model_path)\n",
    "        fluid.save_dygraph(optimizer.state_dict(), model_path)\n",
    "        \n",
    "        # Separate copy of the weights with the lowest validation loss\n",
    "        if avg_valid_loss < best_loss:\n",
    "            fluid.save_dygraph(model.state_dict(), model_path + '-best')\n",
    "            \n",
    "            best_epoch = epoch_id + 1\n",
    "            best_accu = avg_valid_accu\n",
    "            best_loss = avg_valid_loss\n",
    "    \n",
    "    # Summary\n",
    "    train_time = time.time() - train_time # total training time in seconds\n",
    "    summary = done_format.format(train_time, best_epoch, best_loss, best_accu)\n",
    "    print(summary)\n",
    "    \n",
    "    with open(result_txt, 'a') as file:\n",
    "        file.write(summary + '\\n')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": false
   },
   "source": [
    "### 模型预测"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "infer time: 0.014444s, infer value: cattle\n"
     ]
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAMgAAADFCAYAAAARxr1AAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4zLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvIxREBQAAGylJREFUeJztnWlsXOd1ht9zZ+Um7pQoURIteZUVW04c17HjVFmcOGkAJ0VhJGgDA3WWAgnaoPljuECbAv2RAk2CoghSJKhrB0jjpHFSu47T2HGdOnYTWbIta7MsibQW7hSX4XAWznK//phhyuF7eDniSBQpnwcQxDm8c+937/DMved857yfOOdgGIaOd7kHYBhrGXMQwwjAHMQwAjAHMYwAzEEMIwBzEMMIwBzEMAIwBzGMAGpyEBG5R0TeFJFTIvLgxRqUYawVZKUz6SISAnACwN0ABgDsB/Bp59yxpd7T0dHhent7V3Q8Yzn4c8zPzZEtlU6TrbFpg7rHcDhc+7BWgK/YisWCuu3cXJZsoTB/7+dylduNjYwjMZ2U5cZSyxW4DcAp51w/AIjIYwDuBbCkg/T29uLAgQM1HNJYkiI7w8jZPrLte/lVst31oXvUXba1d9Q+rmUoKrZ0ka3J2Un1/f19b5Cttb2BbGfPnqx4/eefe6iq8dXyiLUFwLkFrwfKtgpE5PMickBEDoyPj9dwOMNYfS55kO6c+45z7lbn3K2dnZ2X+nCGcVGp5RFrEMDWBa97yrYLwqqJLxxfeR6X/BTZkmP9ZHv+yZ/wdkl+jgeAP/nsZ9mofF6+r3yGylevAz/y55X3Dg2fJdvk9IA6xuFzR8nWf/I82RIzlddnLptS97eYWu4g+wFcIyJXiUgUwKcAPFnD/gxjzbHiO4hzriAiXwLwCwAhAA8759idDWMdU1Mezzn3NICnL9JYDGPNYTPphhHA5ZkJWgaRZedv3jZoKQxPlNmDYpLfm+G0eoOfI9vE8Ih67NGRUbKFhL9Tm1uayRaJRsjmK0G6czwtGOa3Il/MqGNs39hOttFxDtKH+4Yq95fPq/tbjN1BDCMAcxDDCMAcxDACMAcxjADWZJC+WmhVo87nor/CFAd9mcQsvzfKRXIbtmzWD64Eu6IErJ7Ps+Yzw+fIdvrIb8n21hvHeX9eVNkfz1wDwK+efpxsrZu3ku2OO+/iN4e5QnhiOkG2uVlOEGSzY2RzBU5CAMDYJFcLTE3z5+X8xde7ukSQ3UEMIwBzEMMIwBzEMAIwBzGMAN7WQTp8npE+f4oD27FXXiRbepIDzpEcf99ce9de9dDX3Hwr2bwIfxyHjx4m22vPP0+2pBK4z4zxTHgkHCNbdmKIbADw/M/OkO2G3/8I2d7zvg/yPud4xn5qjPfXv59L+UaHuBOyffs2dYxpn8vW82m+jlGvq+K1VPmnb3cQwwjAHMQwAjAHMYwAzEEMIwBzEMMIoKYsloicBpBESd6o4Jzj1MwaxmW5rGTiTc6gYHqGTG0hRcjM48xN/wvPqscOOy51iG/mTM33fvyfZDt64CDZdrRymUubx2NsUDJlxZDSgAGg/wRnt1488WOydffcSLa7bruBbOPH/5dsrz/zU7LNTbMARWpwlzrG+l3vYlsd63k1XdVa8Toaq04+4WKked/vnOPiF8O4ArBHLMMIoFYHcQCeEZFXROTz2gamrGisZ2p1kPc6594J4KMAvigi71u8gSkrGuuZWmV/Bsv/j4nIT1EStH7hgnZyGfUZvCj3RjR2cf/G+MBbZMuOs9JfQ5T7OWay+gke/61SvtK6nWzPPPMSb5fk3ogmr5ttrXGypeY4cD9+VhdtGEmxZMTABAfQ33/kX3m7g11kS59j4fKGIpeKxOq4HGYuxar0ALC9kQNyb+PVZMtK5Wcd0pQhFFZ8BxGRBhFpmv8ZwIcBHFnp/gxjLVLLHWQjgJ+WJXrCAP7NOfdfF2VU
hrFGqEV6tB/AzRdxLIax5rA0r2EEcPn7QTTpwGoD96VWTqjy/U5ZYmzTO/immJ+dJlvf2TfJlp7kNHYuVqce+8QJXhkp1cjqgeE8n+TMBK+2lFBWVYpv58B9ZoqD7ENn9CB9PMdJjKZmVlE8e+p1su2b5CUVrungwDga4fObnmNbU5d+HYeHuA9mQ30bH6dtkQKjVLfsht1BDCMAcxDDCMAcxDACMAcxjAAue5CuxUpKJfgS772A9Q2VJRVEWR8vEuPZ5y233cn7UyZih1/lWe8eRYkQACbOs2DEoX2vka0uzIF7RxMHz3vv4jH+3s1cIv5P3/oW2ZIZLtMH9GuhKRymlVnu2FZelsB3HLiPjnErQbh1I9mkQS9Tev0otyckXmHhje4dOypep2b4uBp2BzGMAMxBDCMAcxDDCMAcxDACWPUgffGi85qH+krwnc1x/3hUmQkH9HX0PG16XQncC8r0fN8kdxRPKQHs3LW7yXbju+5Qx5g/y7PhP/rZL3m7DJeDf/KevWT7w49/mGwnT/HSAGMpTg7kXEgdY8TxttEwb9sU52vR0MJBdSLP59KwkWf7XR0vnTAwri9/UMxwEiOnaAg8/2RloXlymqsjNOwOYhgBmIMYRgDmIIYRgDmIYQSwbJAuIg8D+DiAMefc7rKtDcAPAfQCOA3gPucc11EvwncOc/nKWdu40hc+k+b1/17av49sGxob1ePccuNNZGuqqydbscj92YPjLJb2qxc5eH7rLK/rN6fMSMc296pjLCR5VnnsDC8PMJvka7Gzl2fnw+CAejrBwWrO5yC7UNRWawT8NAfGnuMSglCcP8OJSf5zGB3jZEedsq5jQzMnZBpbeDsAaFKSBnVhTrRs7WipeN13Tl/yYTHV3EEeAXDPItuDAJ5zzl0D4Lnya8O44ljWQZxzLwBYnJO8F8Cj5Z8fBfCJizwuw1gTrDQG2eicGy7/PIKSgIPKQuG48yYcZ6wzag7SnXMOSze/VgjHdZhwnLHOWOlM+qiIdDvnhkWkGwCv/K4gAsiioGpmloPQ/QdfJdvZ4UGyxaIsMAYAnW0sJnZd706yJWYmyHbwIAu6DZ8+RraRsxxwjk3xuRw8zIrmAHBbz/Vk27GJv0Cm2ri/urmDZ5/PDXFf+fAwB6KpJAfPLY16v3dqloP0mSmuANjR1UO2xjj/aaXrFGX5AidKiikeY9HTy9NzrVxWjzAnLJqbK88xHKru3rDSO8iTAO4v/3w/gCdWuB/DWNMs6yAi8gMAvwFwnYgMiMgDAL4G4G4ROQngQ+XXhnHFsewjlnPu00v8itf+NYwrDJtJN4wAVrXc3flAca4ygHpp38u03StHD5Ft5/UcCA6dS6jH+Y+nniPbxz+WJ1vfaRZv6zvHSu5eiMu5J5VZ4cGB02SLF9+tjvEdvb1k+7M//QzZtNnwnS0s3jY0xEmMk4c5uZCc4FR7c7sS6AIoFpQydmXSfUtrE9mcshyd+PzmkMcJ0FBIaUPI8+cHAGlF1C8U5pn9ol+ZDHDQqwcWY3cQwwjAHMQwAjAHMYwAzEEMIwBzEMMIYFWzWEW/iORsZebpv1/gXov2zVwqMpfl/okz/bpsvyiZkZcPserhESVbJsolCWmXKcw9C3s/uIdsXa1cKgIAhTRneXZfdx3ZPGW5goFfcJau7jxnc+5u4nUCN13LvTIHxofJBgDH67j3o7eHy1w6lbKSbJbLVLS+E9/n7JS2fmAsrJfD5JSelajS++NF9LKk5bA7iGEEYA5iGAGYgxhGAOYghhHAqgbp4gkiDZXBUnMbCy8MDrKk/aHXeQn2M6e4/wIAuns4oGvfxCUbvs+9CFOTvM+IEvT37lAC4M1ccpGZ00skclkO0ouK6EPmNJeQpE9zUJ1IcDBfp5SkvHsbl+x0x3jcALBhgvtJwq0snuBH+Dq6IgfaogTkxTwnX0SLpxWxidI+ufejMMf7jHqL329rFBpGzZiDGEYA5iCGEYA5iGEEsFJlxa8C
+ByA+eaCh5xzTy+3r1Q6i32vVfZgFBXp/VCIh/VWP/dpDA7qQXpjK4sfFIutZEsmeW09LUi/Sglsuzo5SB8YOEG21rAusx+5kRMJ4QRL+Z87eJRsR2d4GYGfHePtEj4Hqy1xnmX+8HW3qmO8I8oKjudGT5Mt1MwBeaGeezrySvDsfE5MOJ8/fy3wBoBiUZmJd8qM/eKlMqpc33KlyooA8E3n3J7yv2WdwzDWIytVVjSMtwW1xCBfEpFDIvKwiPDzS5mFyoqJKlf1MYy1wkod5NsAdgLYA2AYwNeX2nChsmJzS8tSmxnGmmRFM+nOudH5n0XkuwCequZ9c7kM3jp9uHIAilR9VzuXu4vSZB+v02dXP/SBj5Dt+l07yFacYwXHrjZFOr97G9k623j2ecdWLlff1rlZHaMm7JcY4uUPJmZYtLIfHJg23cRl7IUMVw9MT7LQxRNnWNwBAG7s4tL2q7Rp7hFOLmSaeYbbFbhFoFDgIN3Pc9BfXGLmO53lpEq8QVlbsW7xuC/hTHpZbnSeTwLgOhDDuAKoJs37AwB7AXSIyACAvwGwV0T2oOSGpwF84RKO0TAuGytVVvyXSzAWw1hz2Ey6YQSwquXu0aiPzb2VAV1rB8/s5vMcuH3kD1ihcGKCg0MACMc5SMvleJ+33HIj2bIpDiSHlKUO9tzA793Zu51s0+d12f7hES4lnzw3QDbvat7nXe/fS7asx4HtzCxfnwJfGhx98zAbAZx98xTZukIc3G7wOIHifN7OE95OlJYDpwyysERMnVMUF8NFRZmxUHktnDLbrmF3EMMIwBzEMAIwBzGMAMxBDCOAVQ3Sk6kEXtj/8wpbQQnItvVyufqeO3aR7UyfLhznCQe7k7O8HqFf5Jn4ZIKDxokZDrRffp1npI/38ez64KAepMeV8u3rY7wMgdfAM/EjSln8S/t/TbaCEodGYlxmn5jVVx/ORfj6JOKcDAiHeLs0+PyKSv94aHEZOoCwYssraxkCgCf8HR8K83iyc5XJF19JIqj7r2orw3ibYg5iGAGYgxhGAOYghhHAqgbpsXgYO6+uDETzSrlz1yZtVphLwZMpvdExHOaS7HyR19tLJDmAzitTtm09nDSIxDhID8W5V3z79fp3kF9ke1OYg/xfv8jrKB49yWJyTU3cayOeorqe40qBiWn9OvqO3+8UtfqkokCfyXG/vwjPcEejvJ6gZsso6v4AEI7y34rn8bUtUILAgnTDqBlzEMMIwBzEMAIwBzGMAKrpKNwK4HsANqIU2XzHOfePItIG4IcAelHqKrzPOcfR2gIa6uK4dU9l3/asUpJ97NjrZJuc5l1fv2u3epymxg3amZBlbJwDtXyOt0tO8zJfMymefW5v26TYdMGX2Sx/N8VDHGiH6zlwL+b5mkWFVfLrG1mJ3VMSAdPj59QxtnT3kq01yn8yiUkWzPOFky+xGAffnhK4Fwpcwq61QABAg7LcWlEpIWhorFS69zxddJDGV8U2BQBfcc7tAnA7gC+KyC4ADwJ4zjl3DYDnyq8N44qiGuG4Yefcq+WfkwDeALAFwL0AHi1v9iiAT1yqQRrG5eKCYhAR6QVwC4B9ADY65+ZXchlB6RFMe8/vhOOmJ3mewDDWMlU7iIg0AngcwJedcxUzbM45hyVmXhYKx7W08TOxYaxlqnIQEYmg5Bzfd879pGwendfHKv/PCmeGsc6pJoslKMn8vOGc+8aCXz0J4H4AXyv//8Ry+yr6BSRmKwUQPHBZyEyCsxDHj3PW6FT//6jH6dnGyow37dlJtm3KdnUeZ8CcIgJQVPpYohHutRCuhAAA1Gf4httdz2O8ZQ9naTqaudzjpRdeIltiirWQtf6b8UH9u801cH9K8VoeI5TrowlnxMJ8MTIpLknxi9z7EY3r3+UhRXEzl1GUKRZXGlVXaVJVLdadAD4D4LCIHCzbHkLJMX4kIg8AOAPgvuoOaRjrh2qE416ENolQ4oMXdziGsbawmXTDCMAcxDAC
WNV+EE+A+milTzqfg6w7b38X2XbuvIFs/WdOq8cZG2fRhukJRSY/wgmC0QwnA1paOHBvauKSDRdRylRmuG8EANoaeN3Dzi7uO0lu5cB//29+Q7aJaVZ/9JVrqyHcKgMAaGvjX7Rt4XKYlPI1G1HEFKLachXC0XImw6U0ztOj6oKizKiddnrRPqu9NnYHMYwAzEEMIwBzEMMIwBzEMAJY1SAd4uCFKoMqL6LI6SsL03ds2kK2G3br6/9lsxzk+Yqq3/D5YbKNJTjYHZsZJdumbg6om5s5qPWX6DuYzfN300T2ZbINTrKwxJFjPGs+l+Vxx+NLRN+LaGjWA+CtbUrvR/Is2bwWPk5LhKsUfHBPhyqw4Pizmk3q1zHkKYG/sgAkTfYvNbO3CLuDGEYA5iCGEYA5iGEEYA5iGAGsapCezc3hxFDlunfNLTwjHctxYLohzs1WrcpsNgDEldJoDywY0NXK5dyRMM9czyR5dj3kOMqbmeby8tFxXnYBABKjrBR5qoPFKnqabyHbH9/3PrId3s/v1dZlbGllEYk5pUwfANw0VwEcOXaIbL2dLBjR3sAl+QVFCXNCKW3fEOHZeqeIOwDAbIIFNeL1/LdSv6FyjJ6nVzgsxu4ghhGAOYhhBGAOYhgBmIMYRgC1KCt+FcDnAMxHsA85554O2lfRL2J6tjIAzxZY1j6mLC2Qb2omW3J2KXU8LmWur+PArbG+m2zxKAecnc1c7p5X1A215RQGTg2pIwwrSxMcGmWFw3PKZPi1US79b1Ouz+YurjTwlPLwbL0eAE9EuFd9CzgxUhfmY9c1KIqQaT6ZfJFVFHNZXqIhn9PXKEwrypyxGB+7tbVS9TIUrk5jpJos1ryy4qsi0gTgFRF5tvy7bzrn/qGqIxnGOqSanvRhAMPln5MiMq+saBhXPLUoKwLAl0TkkIg8LCKqSvNCZcVUgm+nhrGWqUVZ8dsAdgLYg9Id5uva+xYqKzYoVbqGsZapaiZdU1Z0zo0u+P13ATy13H6ikTh6Nl5dYSsoUvWeUq6cyfCs8Ni0rvWrzXxv3c5LE6QVOf5skvfZ2KjMFLcrs/ARFnnbsV1f/6++kQPW/j4u3Y6FlSUMuvmatWzkRMLsLM8yh4ocAO+88WqyAYB/nMvO8wUedzymLEHg8RjbG3m7cITPeeo8Vx+Iz/oBAJDO8FNJOMbbeqHKP3VtvUSNZe8gSykrzsuOlvkkgCNVHdEw1hG1KCt+WkT2oJT6PQ3gC5dkhIZxGalFWTFwzsMwrgRsJt0wAljVcnfnisgVKoPgWIxLrRvquNy5WOCZ1HSClcEBoKGeA79ingPyyTSvexhX1uDTFNp9jwPYdI5n9rs2aeslAvX1HLBu2qSUiBf5OHM+zx63t3EPeCbB28UjnHAI1fN2ABAf54C8boTPx/M58C+Ckx1eiD/rugb+rNMpTshE4rrQW9FxQsYXDtwzhcoqB1/pe9ewO4hhBGAOYhgBmIMYRgDmIIYRwKoG6UW/iFS6cma54LNoWXKWhdpCwkGtCAe1ANDcxPZ0mvcZUZYEkzAH+KksB9/JIS5t12auoZwfADifM+chRR3e95VgV8m6F9PcIhAOcWCbSnNAnczpffPSzLP40sABfeo8B9V5JQgugI89l+HrmHccZA8MD6pjHBnjSoXOzZwMcOnKJE9RKfvXsDuIYQRgDmIYAZiDGEYA5iCGEYA5iGEEsLqlJr6HfKayVCE1y83z2kLyuRxnaaJKuQcATL3FJSgzKc6C7H7HtWRLjHBGxxO+TOoad0pm6q0+PfsSi3JWrqWNsy/Nrfwd1tzCZTPIcbYrrpSzJGZZJCOd5iwUALiMIvAQ4cxfHlx+4ucVgYYQfy75MGex0nnOTPWfZUELAEgm+G+gpYf7QQpe5Tk66NnFxdgdxDACMAcxjADMQQwjgGpabuMi8rKIvC4iR0Xkb8v2q0Rkn4icEpEfiojyYGwY65tqgvQ5AB9wzs2WxRteFJGfA/hLlITj
HhORfwbwAEpKJ0uSz/kYGqgsx/CVwDYa4RKHwWEOnnM5XRAhrCxh0NLKgeTgsFLS4vF4PPD+6pW+Ck2VMRzTpY6OnzpOts1ZHmP4PJdnRCKcIGisZzXBhgZWPMxkOEgPRZfqteAAujHew9t5SsNMhktSpgp8vaWLy3MmZ/mzTs7qY8w6/o7vfScrT+6+ZXvF64OHn1H3t5hl7yCuxHwxUqT8zwH4AIAfl+2PAvhEVUc0jHVEVTGIiITKgg1jAJ4F0Adg2jk3nwccwBJqiwuF49KzejrRMNYqVTmIc67onNsDoAfAbQCur/YAC4Xj6hstTDHWFxeUxXLOTQN4HsB7ALSI/G4GrQeAPiNmGOuYapY/6ASQd85Ni0gdgLsB/D1KjvJHAB4DcD+AJ5bb19xcHn19w5X7V5YqaGpk28wU+3IyqT+y7drNsv+921kJcWDoNB+7iSWGXZ5nXesbOKCOKYF77zZdwa+tjWeas1meaZ5W1glMTClqlG3Kun557m3xPD5uInVeHWOuyLPz0wkWSdiQ4hn7mBI8Zz3eXyzK2yWSSh9LSv8ub97CTyXxTkW0o7EyOeGUXhmNarJY3QAeFZEQSnecHznnnhKRYwAeE5G/A/AaSuqLhnFFUY1w3CGUFN0X2/tRikcM44rFZtINIwBzEMMIQJyrruz3ohxMZBzAGQAdAPTIcP1h57I2We5ctjvnOpfbyao6yO8OKnLAOXfrqh/4EmDnsja5WOdij1iGEYA5iGEEcLkc5DuX6biXAjuXtclFOZfLEoMYxnrBHrEMIwBzEMMIYNUdRETuEZE3y626D6728WtBRB4WkTERObLA1iYiz4rIyfL/XO24BhGRrSLyvIgcK7dS/0XZvu7O51K2ha+qg5QLHr8F4KMAdqG0Uu6u1RxDjTwC4J5FtgcBPOecuwbAc+XX64ECgK8453YBuB3AF8ufxXo8n/m28JsB7AFwj4jcjlLV+Tedc1cDmEKpLfyCWO07yG0ATjnn+p1zOZRK5e9d5TGsGOfcCwAWN8Lfi1LLMbCOWo+dc8POuVfLPycBvIFSV+i6O59L2Ra+2g6yBcBCibwlW3XXERudc/NNLiMANl7OwawEEelFqWJ7H9bp+dTSFh6EBekXEVfKma+rvLmINAJ4HMCXnauUMVlP51NLW3gQq+0ggwC2Lnh9JbTqjopINwCU/2ex4TVKWcbpcQDfd879pGxet+cDXPy28NV2kP0ArilnF6IAPgXgyVUew8XmSZRajoEqW4/XAiIiKHWBvuGc+8aCX6278xGRThFpKf883xb+Bv6/LRxY6bk451b1H4CPATiB0jPiX6328Wsc+w8ADAPIo/RM+wCAdpSyPScB/BJA2+UeZ5Xn8l6UHp8OAThY/vex9Xg+AG5Cqe37EIAjAP66bN8B4GUApwD8O4DYhe7bSk0MIwAL0g0jAHMQwwjAHMQwAjAHMYwAzEEMIwBzEMMIwBzEMAL4P/reBAlsXKWPAAAAAElFTkSuQmCC\n",
      "text/plain": [
       "<Figure size 216x216 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "import paddle.fluid as fluid\n",
    "from PIL import Image\n",
    "import numpy as np\n",
    "import time\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "image_path = './work/out/img.png' # path of the image to classify\n",
    "model_path = './work/out/ssrnet-best' # best-validation weights written by the training cell (model_path + '-best')\n",
    "\n",
    "# Load an image\n",
    "def load_image(image_path):\n",
    "    \"\"\"\n",
    "    Read an image file and convert it to the network's input format.\n",
    "\n",
    "    Args:\n",
    "        image_path - path of the input image\n",
    "    Returns:\n",
    "        image - float32 array of shape (1, 3, 32, 32), normalized per channel\n",
    "    \"\"\"\n",
    "    # Decode inside a context manager so the file handle opened by Image.open is closed,\n",
    "    # and force RGB: RGBA / palette / grayscale files would otherwise produce 4 or 1\n",
    "    # channels and fail to broadcast against the 3-channel mean/stdv below.\n",
    "    with Image.open(image_path) as source:\n",
    "        image = source.convert('RGB')\n",
    "    \n",
    "    # Resize to 32x32. Image.ANTIALIAS is a deprecated alias of LANCZOS (removed in Pillow 10).\n",
    "    image = image.resize((32, 32), Image.LANCZOS)\n",
    "    image = np.array(image, dtype=np.float32) # HWC, float32 in [0, 255]\n",
    "\n",
    "    # Per-channel normalization.\n",
    "    # NOTE(review): these look like the usual CIFAR-10 statistics rather than CIFAR-100's; they\n",
    "    # must match whatever normalization valid_augment applied during training - TODO confirm.\n",
    "    mean = np.array([0.4914, 0.4822, 0.4465]).reshape((1, 1, -1)) # channel means\n",
    "    stdv = np.array([0.2471, 0.2435, 0.2616]).reshape((1, 1, -1)) # channel standard deviations\n",
    "    \n",
    "    image = (image/255.0 - mean) / stdv # scale to [0, 1], then standardize\n",
    "    image = image.transpose((2, 0, 1)).astype(np.float32) # HWC -> CHW; the float64 mean/stdv promoted the dtype, so cast back\n",
    "    \n",
    "    # Add the batch dimension\n",
    "    image = np.expand_dims(image, axis=0)\n",
    "    \n",
    "    return image\n",
    "\n",
    "# Predict\n",
    "with fluid.dygraph.guard():\n",
    "    # Load and preprocess the image\n",
    "    image = load_image(image_path)\n",
    "    image = fluid.dygraph.to_variable(image)\n",
    "    \n",
    "    # Build the network and restore the best weights (the optimizer state is not needed)\n",
    "    model = SSRNet()\n",
    "    model_dict, _ = fluid.load_dygraph(model_path)\n",
    "    model.set_dict(model_dict)\n",
    "    model.eval()\n",
    "    \n",
    "    # Forward pass. Inference never calls backward(), so skip recording the autograd\n",
    "    # graph; perf_counter is the monotonic clock intended for measuring intervals.\n",
    "    with fluid.dygraph.no_grad():\n",
    "        infer_time = time.perf_counter()\n",
    "        infer = model(image)\n",
    "        infer_time = time.perf_counter() - infer_time\n",
    "    \n",
    "    # Class names, sorted alphabetically - presumably so that index i matches CIFAR-100\n",
    "    # fine label i (the official label names are in alphabetical order); NOTE(review): confirm.\n",
    "    vlist = sorted(['beaver', 'dolphin', 'otter', 'seal', 'whale',\n",
    "                    'aquarium fish', 'flatfish', 'ray', 'shark', 'trout',\n",
    "                    'orchids', 'poppies', 'roses', 'sunflowers', 'tulips',\n",
    "                    'bottles', 'bowls', 'cans', 'cups', 'plates',\n",
    "                    'apples', 'mushrooms', 'oranges', 'pears', 'sweet peppers',\n",
    "                    'clock', 'keyboard', 'lamp', 'telephone', 'television',\n",
    "                    'bed', 'chair', 'couch', 'table', 'wardrobe',\n",
    "                    'bee', 'beetle', 'butterfly', 'caterpillar', 'cockroach',\n",
    "                    'bear', 'leopard', 'lion', 'tiger', 'wolf',\n",
    "                    'bridge', 'castle', 'house', 'road', 'skyscraper',\n",
    "                    'cloud', 'forest', 'mountain', 'plain', 'sea',\n",
    "                    'camel', 'cattle', 'chimpanzee', 'elephant', 'kangaroo',\n",
    "                    'fox', 'porcupine', 'possum', 'raccoon', 'skunk',\n",
    "                    'crab', 'lobster', 'snail', 'spider', 'worm',\n",
    "                    'baby', 'boy', 'girl', 'man', 'woman',\n",
    "                    'crocodile', 'dinosaur', 'lizard', 'snake', 'turtle',\n",
    "                    'hamster', 'mouse', 'rabbit', 'shrew', 'squirrel',\n",
    "                    'maple', 'oak', 'palm', 'pine', 'willow',\n",
    "                    'bicycle', 'bus', 'motorcycle', 'pickup truck', 'train',\n",
    "                    'lawn-mower', 'rocket', 'streetcar', 'tank', 'tractor'])\n",
    "    print('infer time: {:f}s, infer value: {}'.format(infer_time, vlist[np.argmax(infer.numpy())]) )\n",
    "    \n",
    "    # Show the input image; the context manager closes the file once it has been drawn\n",
    "    with Image.open(image_path) as shown:\n",
    "        plt.figure(figsize=(3, 3))\n",
    "        plt.imshow(shown)\n",
    "        plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "PaddlePaddle 1.8.4 (Python 3.5)",
   "language": "python",
   "name": "py35-paddle1.2.0"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
